207 lines
5.8 KiB
Python
207 lines
5.8 KiB
Python
from emis_funky_funktions import *
|
|
from typing import *
|
|
|
|
from dataclasses import dataclass, field
|
|
from functools import reduce
|
|
|
|
MonoType: TypeAlias = 'FunctionTy | HoleTy | IntTy | VarTy'
|
|
Substitution: TypeAlias = 'tuple[MonoType, VarTy]'
|
|
|
|
@dataclass(frozen=True)
class UnificationError:
    """Records the two types that could not be unified with each other."""

    # The type whose `unify` method reported the failure.
    ty1: MonoType
    # The type it was being unified against.
    ty2: MonoType
|
|
|
|
@dataclass(frozen=True)
class VarTy:
    """A type variable.

    Two variables are the same variable when their `id`s match; `name` is
    only used for display and is excluded from equality and hashing.
    """

    # Identity of the variable.
    id: int
    # Display name returned by `__repr__`; not part of equality.
    name: str = field(compare=False)

    def subst(self, replacement: MonoType, var: 'VarTy') -> MonoType:
        """Return `replacement` if `var` is this variable, otherwise `self`."""
        return replacement if var.id == self.id else self

    def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
        """Unify this variable with `other`.

        Returns:
            `Ok(())` when `other` is this same variable (nothing to do),
            `Ok(((other, self),))` binding this variable to `other`, or
            `Err(UnificationError(self, other))` when `other` is a different
            type that contains this variable.
        """
        if isinstance(other, VarTy) and other.id == self.id:
            return Ok(tuple())
        # Occurs check: binding a variable to a type that mentions that same
        # variable (e.g. `a ~ a -> int`) would describe an infinite type.
        if self in other.free_vars():
            return Err(UnificationError(self, other))
        return Ok(((other, self),))

    def free_vars(self) -> FrozenSet['VarTy']:
        """Return the set containing only this variable."""
        return fset(self)

    def __repr__(self) -> str:
        """Display the variable by its name."""
        return self.name
|
|
|
|
@dataclass(frozen=True)
|
|
class FunctionTy:
|
|
arg_type: MonoType
|
|
ret_type: MonoType
|
|
|
|
def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
|
|
return FunctionTy(
|
|
self.arg_type.subst(replacement, var),
|
|
self.ret_type.subst(replacement, var)
|
|
)
|
|
|
|
def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
|
|
match other:
|
|
case VarTy(id, str) as var:
|
|
return Ok(((self, var),))
|
|
case FunctionTy(other_arg, other_ret):
|
|
return self.arg_type.unify(other_arg) << (lambda arg_substs:
|
|
subst_all_monotype(self.ret_type, arg_substs)
|
|
.unify(
|
|
subst_all_monotype(other_ret, arg_substs)) << (lambda ret_substs:
|
|
Ok((*arg_substs, *ret_substs))))
|
|
case _:
|
|
return Err(UnificationError(self, other))
|
|
|
|
def free_vars(self) -> FrozenSet[VarTy]:
|
|
return self.arg_type.free_vars() | self.ret_type.free_vars()
|
|
|
|
def __repr__(self) -> str:
|
|
return f'({self.arg_type}) -> {self.ret_type}'
|
|
|
|
@dataclass(frozen=True)
class HoleTy:
    """The hole type, displayed as `[]`.

    It carries no data, contains no type variables, and unifies only with
    another hole or with a type variable.
    """

    def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
        """Substitution leaves a hole unchanged."""
        return self

    def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
        """Unify with `other`: bind a variable, accept a hole, reject the rest."""
        if isinstance(other, VarTy):
            # The variable gets bound to this hole.
            return Ok(((self, other),))
        if isinstance(other, HoleTy):
            # Identical types: no substitutions needed.
            return Ok(tuple())
        return Err(UnificationError(self, other))

    def free_vars(self) -> FrozenSet[VarTy]:
        """A hole has no free variables."""
        return FSet()

    def __repr__(self) -> str:
        """Render as `[]`."""
        return '[]'
|
|
|
|
@dataclass(frozen=True)
class IntTy:
    """The integer type, displayed as `int`.

    It carries no data, contains no type variables, and unifies only with
    itself or with a type variable.
    """

    def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
        """Substitution leaves `int` unchanged."""
        return self

    def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
        """Unify with `other`: bind a variable, accept `int`, reject the rest."""
        if isinstance(other, VarTy):
            # The variable gets bound to `int`.
            return Ok(((self, other),))
        if not isinstance(other, IntTy):
            return Err(UnificationError(self, other))
        # Identical types: no substitutions needed.
        return Ok(tuple())

    def free_vars(self) -> FrozenSet[VarTy]:
        """`int` has no free variables."""
        return FSet()

    def __repr__(self) -> str:
        """Render as `int`."""
        return 'int'
|
|
|
|
@dataclass(frozen=True)
class PolyType:
    """A monotype together with the type variables quantified over it.

    Displayed as `forall A B. <monotype>`, or as the bare monotype when no
    variables are quantified.
    """

    # Variables bound by the `forall`.
    quantified_variables: Collection[VarTy]
    # The type the quantifier ranges over.
    monotype: MonoType

    def subst(self, replacement: MonoType, var: VarTy) -> 'PolyType':
        """Substitute `replacement` for `var` unless `var` is quantified here.

        A quantified variable is bound by this polytype, so substituting for
        it returns `self` unchanged.
        """
        if var in self.quantified_variables:
            return self
        return PolyType(
            self.quantified_variables,
            self.monotype.subst(replacement, var),
        )

    def instantiate(self, ctx: 'Context', name: str) -> MonoType:
        """Replace every quantified variable with a fresh one from `ctx`.

        Args:
            ctx: Source of fresh variables (`ctx.new_type_var`).
            name: Name prefix passed to `ctx.new_type_var` for each variable.

        Returns:
            The monotype with each quantified variable substituted, in
            iteration order of `quantified_variables`.
        """
        # A plain loop works for any Collection (list, tuple, set, ...);
        # sequence patterns in `match` would reject non-sequence collections.
        instantiated = self.monotype
        for quantified in self.quantified_variables:
            instantiated = instantiated.subst(ctx.new_type_var(name), quantified)
        return instantiated

    def free_vars(self) -> FrozenSet[VarTy]:
        """Return the monotype's free variables minus the quantified ones."""
        return self.monotype.free_vars() - FSet(self.quantified_variables)

    def __repr__(self) -> str:
        """Render as `forall <names>. <monotype>` or just the monotype."""
        if not len(self.quantified_variables):
            return repr(self.monotype)
        names = ' '.join(quantified.name for quantified in self.quantified_variables)
        return f'forall {names}. {self.monotype}'
|
|
|
|
# Module-wide counter; `Context.new_type_var` increments it and uses the new
# value as both the id and the name suffix of the variable it creates.
UNIQUE_VAR_COUNTER: int = 0
|
|
|
|
@dataclass(frozen=True)
class Context:
    """An immutable typing context mapping variable names to polytypes.

    Every `with_*` / `subst*` method returns a new `Context`; the receiver is
    never modified.
    """

    # Variable name -> its (possibly quantified) type.
    variable_types: Mapping[str, PolyType]

    def with_(self, var_name: str, var_type: PolyType) -> 'Context':
        """Return a context in which `var_name` is bound to `var_type`.

        An existing binding of the same name is shadowed by the new one.
        """
        # The new entry must come last: in a dict display later keys win, so
        # `{new, **old}` would silently keep the old binding.
        return Context({**self.variable_types, var_name: var_type})

    def with_mono(self, var_name: str, var_type: MonoType) -> 'Context':
        """Like `with_`, wrapping `var_type` in a polytype with no quantifiers."""
        return self.with_(var_name, PolyType([], var_type))

    def with_many_mono(self, bindings: Sequence[tuple[str, MonoType]]) -> 'Context':
        """Bind each `(name, monotype)` pair in order via `with_mono`.

        Returns `self` when `bindings` is empty. When a name appears more
        than once, the last binding wins.
        """
        ctx = self
        for var_name, var_type in bindings:
            ctx = ctx.with_mono(var_name, var_type)
        return ctx

    def subst(self, replacement: MonoType, var: VarTy) -> 'Context':
        """Apply one substitution to every type in the context."""
        return Context({
            name: val.subst(replacement, var)
            for (name, val) in self.variable_types.items()
        })

    def subst_all(self, subst: Sequence[Tuple[MonoType, VarTy]]) -> 'Context':
        """Apply each `(replacement, var)` substitution in order."""
        return reduce(lambda a, s: a.subst(*s), subst, self)

    def new_type_var(self, name_prefix = 'T') -> VarTy:
        """Create a fresh type variable named `<name_prefix>$<n>`.

        `n` comes from the module-level `UNIQUE_VAR_COUNTER`, which is
        incremented on every call regardless of which context is used.
        """
        global UNIQUE_VAR_COUNTER
        UNIQUE_VAR_COUNTER += 1
        return VarTy(UNIQUE_VAR_COUNTER, f'{name_prefix}${UNIQUE_VAR_COUNTER}')

    def instantiate(self, name: str) -> Option[MonoType]:
        """Instantiate the polytype bound to `name`.

        Returns:
            `Some(monotype)` with fresh variables in place of the quantified
            ones, or `None` when `name` is not bound in this context.
        """
        if name in self.variable_types:
            return Some(self.variable_types[name].instantiate(self, name))
        return None

    def free_vars(self) -> FrozenSet[VarTy]:
        """Return the free variables of every type in the context."""
        return FSet(fv for ty in self.variable_types.values() for fv in ty.free_vars())

    def generalize(self, mt: MonoType) -> PolyType:
        """Quantify over every free variable of `mt`.

        Each free variable is renamed for display: `A`..`Z` for the first 26,
        then `A1`..`Z1`, `A2`.. and so on. Ids are preserved.

        NOTE(review): variables that are also free in this context are
        quantified too (`self.free_vars()` is not consulted) -- confirm that
        this is intended by the callers.
        """
        letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        variable_mappings = []
        for i, fv in enumerate(mt.free_vars()):
            # Past 26 variables, append a numeric suffix instead of indexing
            # beyond the end of the alphabet.
            cycle, offset = divmod(i, len(letters))
            pretty_name = letters[offset] if cycle == 0 else f'{letters[offset]}{cycle}'
            variable_mappings.append((VarTy(fv.id, pretty_name), fv))
        return PolyType(
            [new_var for (new_var, _) in variable_mappings],
            subst_all_monotype(mt, variable_mappings),
        )

    def __contains__(self, name: str) -> bool:
        """Return whether `name` is bound in this context."""
        return name in self.variable_types

    def __getitem__(self, i: str) -> PolyType:
        """Return the polytype bound to `i`, as a plain mapping lookup."""
        return self.variable_types[i]
|
|
|
|
# Typing context pre-populated with the built-in functions.
BUILTINS_CONTEXT: Context = Context({}).with_many_mono([
    # '+' is curried: int -> (int -> int)
    ('+', FunctionTy(IntTy(), FunctionTy(IntTy(), IntTy()))),
    # 'S' : int -> int
    ('S', FunctionTy(IntTy(), IntTy())),
])
|
|
|
|
def subst_all_monotype(m: MonoType, substs: Sequence[Tuple[MonoType, VarTy]]) -> MonoType:
    """Apply each `(replacement, var)` substitution to `m`, in order.

    Each step calls `.subst(replacement, var)` on the result of the previous
    one; with no substitutions `m` is returned as-is.
    """
    result = m
    for step in substs:
        result = result.subst(*step)
    return result