JSON-Lang/types_.py
Emi Simpson ca685d7ded
Fix substitution not working for polytypes
Honestly, I don't know how I didn't catch that sooner.
2024-03-15 20:21:43 -04:00

207 lines
5.8 KiB
Python

from emis_funky_funktions import *
from typing import *
from dataclasses import dataclass, field
from functools import reduce
MonoType: TypeAlias = 'FunctionTy | HoleTy | IntTy | VarTy'
Substitution: TypeAlias = 'tuple[MonoType, VarTy]'
@dataclass(eq=True, frozen=True)
class UnificationError:
    """The two monotypes that could not be unified.

    Every producer in this module builds it as ``UnificationError(self, other)``
    from inside ``unify``, so ``ty1`` is the receiver and ``ty2`` the argument.
    """

    ty1: MonoType  # receiver of the failed `unify` call
    ty2: MonoType  # argument of the failed `unify` call
@dataclass(frozen=True)
class VarTy:
    """A type variable.

    Identity is determined solely by ``id``.  ``name`` is declared with
    ``compare=False``, so it takes no part in ``==`` or hashing and only
    affects ``repr``.
    """

    id: int
    name: str = field(compare=False)  # display name only; ignored by ==

    def subst(self, replacement: MonoType, var: 'VarTy') -> MonoType:
        """Return ``replacement`` if this is the variable ``var``, else ``self``."""
        if var.id == self.id:
            return replacement
        return self

    def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
        """Unify this variable with ``other``.

        * If ``other`` is this same variable, no substitution is needed.
        * If this variable occurs inside ``other``, fail with a
          ``UnificationError``.  This is the occurs check: binding
          ``T := T -> int`` would describe an infinite type.
        * Otherwise bind this variable to ``other``.
        """
        if isinstance(other, VarTy) and other.id == self.id:
            return Ok(tuple())
        if self in other.free_vars():
            # Occurs check: T := (... T ...) has no finite solution.
            return Err(UnificationError(self, other))
        return Ok(((other, self),))

    def free_vars(self) -> FrozenSet['VarTy']:
        """A type variable's only free variable is itself."""
        return fset(self)

    def __repr__(self) -> str:
        return self.name
@dataclass(frozen=True)
class FunctionTy:
arg_type: MonoType
ret_type: MonoType
def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
return FunctionTy(
self.arg_type.subst(replacement, var),
self.ret_type.subst(replacement, var)
)
def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
match other:
case VarTy(id, str) as var:
return Ok(((self, var),))
case FunctionTy(other_arg, other_ret):
return self.arg_type.unify(other_arg) << (lambda arg_substs:
subst_all_monotype(self.ret_type, arg_substs)
.unify(
subst_all_monotype(other_ret, arg_substs)) << (lambda ret_substs:
Ok((*arg_substs, *ret_substs))))
case _:
return Err(UnificationError(self, other))
def free_vars(self) -> FrozenSet[VarTy]:
return self.arg_type.free_vars() | self.ret_type.free_vars()
def __repr__(self) -> str:
return f'({self.arg_type}) -> {self.ret_type}'
@dataclass(frozen=True)
class HoleTy:
    """A component-free monotype displayed as ``[]``.

    NOTE(review): presumably the type of a hole/placeholder expression;
    confirm against the parser.  It has no type variables, is unaffected by
    substitution, and unifies only with itself or with a type variable.
    """

    def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
        # There is nothing inside this type for a substitution to touch.
        return self

    def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
        if isinstance(other, VarTy):
            # Bind the variable to this type.
            return Ok(((self, other),))
        if isinstance(other, HoleTy):
            # Identical types: nothing to substitute.
            return Ok(tuple())
        return Err(UnificationError(self, other))

    def free_vars(self) -> FrozenSet[VarTy]:
        return FSet()

    def __repr__(self) -> str:
        return '[]'
@dataclass(frozen=True)
class IntTy:
    """The integer type, displayed as ``int``.

    It contains no type variables, so substitution leaves it unchanged.  It
    unifies only with itself or with a type variable.
    """

    def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
        # A base type has no variables to replace.
        return self

    def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
        if isinstance(other, VarTy):
            # Bind the variable to `int`.
            return Ok(((self, other),))
        if isinstance(other, IntTy):
            # int ~ int: already equal.
            return Ok(tuple())
        return Err(UnificationError(self, other))

    def free_vars(self) -> FrozenSet[VarTy]:
        return FSet()

    def __repr__(self) -> str:
        return 'int'
@dataclass(frozen=True)
class PolyType:
    """A type scheme ``forall quantified_variables. monotype``.

    With no quantified variables it behaves like the bare monotype.
    """

    quantified_variables: Collection[VarTy]
    monotype: MonoType

    def subst(self, replacement: MonoType, var: VarTy) -> 'PolyType':
        """Substitute ``replacement`` for ``var`` in the body of the scheme.

        Quantified variables are bound by the scheme itself, so substituting
        for one of them leaves the scheme unchanged.
        """
        if var in self.quantified_variables:
            return self
        return PolyType(
            self.quantified_variables,
            self.monotype.subst(replacement, var),
        )

    def instantiate(self, ctx: 'Context', name: str) -> MonoType:
        """Replace each quantified variable with a fresh type variable.

        Fresh variables come from ``ctx.new_type_var(name)``: one per
        quantified variable, in iteration order.  Any ``Collection`` is
        accepted (list, tuple, set, frozenset, ...).  Sequence-only
        pattern matching would reject sets.
        """
        monotype = self.monotype
        for alpha in self.quantified_variables:
            monotype = monotype.subst(ctx.new_type_var(name), alpha)
        return monotype

    def free_vars(self) -> FrozenSet[VarTy]:
        """Free variables of the body, minus those bound by this scheme."""
        return self.monotype.free_vars() - FSet(self.quantified_variables)

    def __repr__(self) -> str:
        if len(self.quantified_variables):
            qv = ' '.join(qv.name for qv in self.quantified_variables)
            return f'forall {qv}. {self.monotype}'
        else:
            return repr(self.monotype)
# Global source of fresh type-variable ids.  Context.new_type_var increments
# it before each use, so the first id handed out is 1.
UNIQUE_VAR_COUNTER: int = 0
@dataclass(frozen=True)
class Context:
    """A typing environment mapping variable names to type schemes.

    Every ``with_*`` / ``subst*`` method returns a new ``Context``; none of
    them mutates ``self``.
    """

    variable_types: Mapping[str, PolyType]

    def with_(self, var_name: str, var_type: PolyType) -> 'Context':
        """Return a copy with ``var_name`` bound to ``var_type``.

        The new binding shadows any existing binding of the same name.
        """
        # The new entry must come *last*: in a dict display later keys win.
        # With `**self.variable_types` last, an old binding would override it.
        return Context({**self.variable_types, var_name: var_type})

    def with_mono(self, var_name: str, var_type: MonoType) -> 'Context':
        """Bind ``var_name`` to ``var_type`` with no quantified variables."""
        return self.with_(var_name, PolyType([], var_type))

    def with_many_mono(self, bindings: Sequence[tuple[str, MonoType]]) -> 'Context':
        """Bind each ``(name, monotype)`` pair in order.

        Later pairs shadow earlier ones.  Any iterable of pairs is accepted.
        """
        ctx = self
        for var_name, var_type in bindings:
            ctx = ctx.with_mono(var_name, var_type)
        return ctx

    def subst(self, replacement: MonoType, var: VarTy) -> 'Context':
        """Apply ``var := replacement`` to every scheme in the context."""
        return Context({
            name: val.subst(replacement, var)
            for (name, val) in self.variable_types.items()
        })

    def subst_all(self, subst: Sequence[Tuple[MonoType, VarTy]]) -> 'Context':
        """Apply each ``(replacement, var)`` step in order."""
        return reduce(lambda a, s: a.subst(*s), subst, self)

    def new_type_var(self, name_prefix: str = 'T') -> VarTy:
        """Return a type variable with a fresh id.

        The id comes from the module-level ``UNIQUE_VAR_COUNTER``, which is
        incremented on every call.  The display name is ``<prefix>$<id>``.
        """
        global UNIQUE_VAR_COUNTER
        UNIQUE_VAR_COUNTER += 1
        return VarTy(UNIQUE_VAR_COUNTER, f'{name_prefix}${UNIQUE_VAR_COUNTER}')

    def instantiate(self, name: str) -> Option[MonoType]:
        """Instantiate ``name``'s scheme with fresh variables.

        Returns ``None`` if ``name`` is unbound.
        """
        if name in self.variable_types:
            return Some(self.variable_types[name].instantiate(self, name))
        return None

    def free_vars(self) -> FrozenSet[VarTy]:
        """Union of the free variables of every scheme in the context."""
        return FSet(fv for ty in self.variable_types.values() for fv in ty.free_vars())

    def generalize(self, mt: MonoType) -> PolyType:
        """Quantify ``mt`` over the type variables not free in this context.

        A variable that is free in the context is still constrained by an
        enclosing binding and must stay monomorphic; quantifying over it
        would be unsound.  Quantified variables keep their ids but are
        renamed for display as A, B, ..., Z, then A1, B1, ...  Names are
        assigned in increasing id order so the result is deterministic.
        """
        generalizable = sorted(mt.free_vars() - self.free_vars(), key=lambda fv: fv.id)
        variable_mappings = [
            (VarTy(fv.id, self._quantified_name(i)), fv)
            for (i, fv) in enumerate(generalizable)
        ]
        return PolyType(
            [new_var for (new_var, _) in variable_mappings],
            subst_all_monotype(mt, variable_mappings),
        )

    @staticmethod
    def _quantified_name(index: int) -> str:
        """Display name for the ``index``-th quantified variable (A..Z, A1..Z1, ...)."""
        letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        round_no, offset = divmod(index, len(letters))
        letter = letters[offset]
        return letter if round_no == 0 else f'{letter}{round_no}'

    def __contains__(self, name: str) -> bool:
        return name in self.variable_types

    def __getitem__(self, i: str) -> PolyType:
        return self.variable_types[i]
# The starting typing environment: '+' : int -> (int -> int) and
# 'S' : int -> int, both monomorphic.  They are bound in this order, exactly
# as two successive `with_mono` calls would bind them.
BUILTINS_CONTEXT: Context = Context({}).with_many_mono((
    ('+', FunctionTy(IntTy(), FunctionTy(IntTy(), IntTy()))),
    ('S', FunctionTy(IntTy(), IntTy())),
))
def subst_all_monotype(m: MonoType, substs: Sequence[Tuple[MonoType, VarTy]]) -> MonoType:
    """Apply a sequence of ``(replacement, var)`` substitutions to ``m``.

    The steps are applied left to right, each one to the result of the
    previous step.  An empty sequence returns ``m`` itself.
    """
    result = m
    for step in substs:
        result = result.subst(*step)
    return result