from emis_funky_funktions import * from typing import * from dataclasses import dataclass, field from functools import reduce MonoType: TypeAlias = 'FunctionTy | HoleTy | IntTy | VarTy' Substitution: TypeAlias = 'tuple[MonoType, VarTy]' @dataclass(frozen=True) class UnificationError: ty1: MonoType ty2: MonoType @dataclass(frozen=True) class VarTy: id: int name: str = field(compare=False) def subst(self, replacement: MonoType, var: 'VarTy') -> MonoType: if var.id == self.id: return replacement else: return self def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]: match other: case VarTy(id, str) if id == self.id: return Ok(tuple()) case _: return Ok(((other, self),)) raise Exception('Unreachable') def free_vars(self) -> FrozenSet['VarTy']: return fset(self) def __repr__(self) -> str: return self.name @dataclass(frozen=True) class FunctionTy: arg_type: MonoType ret_type: MonoType def subst(self, replacement: MonoType, var: VarTy) -> MonoType: return FunctionTy( self.arg_type.subst(replacement, var), self.ret_type.subst(replacement, var) ) def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]: match other: case VarTy(id, str) as var: return Ok(((self, var),)) case FunctionTy(other_arg, other_ret): return self.arg_type.unify(other_arg) << (lambda arg_substs: subst_all_monotype(self.ret_type, arg_substs) .unify( subst_all_monotype(other_ret, arg_substs)) << (lambda ret_substs: Ok((*arg_substs, *ret_substs)))) case _: return Err(UnificationError(self, other)) def free_vars(self) -> FrozenSet[VarTy]: return self.arg_type.free_vars() | self.ret_type.free_vars() def __repr__(self) -> str: return f'({self.arg_type}) -> {self.ret_type}' @dataclass(frozen=True) class HoleTy: def subst(self, replacement: MonoType, var: VarTy) -> MonoType: return self def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]: match other: case VarTy(id, str) as var: return Ok(((self, var),)) case HoleTy(): return Ok(tuple()) case _: return Err(UnificationError(self, other)) def free_vars(self) -> FrozenSet[VarTy]: return FSet() def __repr__(self) -> str: return '[]' @dataclass(frozen=True) class IntTy: def subst(self, replacement: MonoType, var: VarTy) -> MonoType: return self def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]: match other: case VarTy(id, str) as var: return Ok(((self, var),)) case IntTy(): return Ok(tuple()) case _: return Err(UnificationError(self, other)) def free_vars(self) -> FrozenSet[VarTy]: return FSet() def __repr__(self) -> str: return 'int' @dataclass(frozen=True) class PolyType: quantified_variables: Collection[VarTy] monotype: MonoType def subst(self, replacement: MonoType, var: VarTy) -> 'PolyType': if var in self.quantified_variables: return self else: return PolyType( self.quantified_variables, self.monotype.subst(replacement, var) ) def instantiate(self, ctx: 'Context', name: str) -> MonoType: match self.quantified_variables: case []: return self.monotype case [alpha, *rest]: return PolyType( rest, self.monotype.subst(ctx.new_type_var(name), alpha) ).instantiate(ctx, name) raise Exception('Unreachable') def free_vars(self) -> FrozenSet[VarTy]: return self.monotype.free_vars() - FSet(self.quantified_variables) def __repr__(self) -> str: if len(self.quantified_variables): qv = ' '.join(qv.name for qv in self.quantified_variables) return f'forall {qv}. {self.monotype}' else: return repr(self.monotype) UNIQUE_VAR_COUNTER:int = 0 @dataclass(frozen=True) class Context: variable_types: Mapping[str, PolyType] def with_(self, var_name: str, var_type: PolyType) -> 'Context': return Context({var_name: var_type, **self.variable_types}) def with_mono(self, var_name: str, var_type: MonoType) -> 'Context': return Context({var_name: PolyType([], var_type), **self.variable_types}) def with_many_mono(self, bindings: Sequence[tuple[str, MonoType]]) -> 'Context': match bindings: case []: return self case [(var_name, var_type), *rest]: return self.with_mono(var_name, var_type).with_many_mono(rest) raise Exception('Unreachable') def subst(self, replacement: MonoType, var: VarTy) -> 'Context': return Context({ name: val.subst(replacement, var) for (name, val) in self.variable_types.items() }) def subst_all(self, subst: Sequence[Tuple[MonoType, VarTy]]) -> 'Context': return reduce(lambda a, s: a.subst(*s), subst, self) def new_type_var(self, name_prefix = 'T') -> VarTy: global UNIQUE_VAR_COUNTER UNIQUE_VAR_COUNTER += 1 return VarTy(UNIQUE_VAR_COUNTER, f'{name_prefix}${UNIQUE_VAR_COUNTER}') def instantiate(self, name: str) -> Option[MonoType]: if name in self.variable_types: return Some(self.variable_types[name].instantiate(self, name)) else: return None def free_vars(self) -> FrozenSet[VarTy]: return FSet(fv for ty in self.variable_types.values() for fv in ty.free_vars()) def generalize(self, mt: MonoType) -> PolyType: variable_mappings = [(VarTy(fv.id, 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'[i]), fv) for (i, fv) in enumerate(mt.free_vars())] return PolyType([new_var for (new_var, _) in variable_mappings], subst_all_monotype(mt, variable_mappings)) def __contains__(self, name: str) -> bool: return name in self.variable_types def __getitem__(self, i: str) -> PolyType: return self.variable_types[i] BUILTINS_CONTEXT: Context = ( Context({}) .with_mono('+', FunctionTy(IntTy(), FunctionTy(IntTy(), IntTy()))) .with_mono('S', FunctionTy(IntTy(), IntTy())) ) def subst_all_monotype(m: MonoType, substs: Sequence[Tuple[MonoType, VarTy]]) -> MonoType: return reduce(lambda a, s: a.subst(*s), substs, m)