Make typechecking slightly less buggy

- The statement ["x", "x"] no longer typechecks in the empty context
- The statement ["ident": {"x": "x"}] now correctly generalizes
- The statement ["compose", "compose", "compose"] no longer goes into an infinite loop
This commit is contained in:
Emi Simpson 2023-03-09 17:48:57 -05:00
parent 792674797d
commit fb55346003
Signed by: Emi
GPG key ID: A12F2C2FFDC3D847
2 changed files with 5 additions and 3 deletions

View file

@ -146,13 +146,14 @@ def let_2_ir(
) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]: ) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
# Parse the rhs # Parse the rhs
standin_ty = context.new_type_var() standin_ty = context.new_type_var()
match json_to_ir(rhs, context.with_mono(lhs, standin_ty)): context_for_rhs = context.with_mono(lhs, standin_ty) if isinstance(rhs, Mapping) else context
match json_to_ir(rhs, context_for_rhs):
case Ok(( rhs_expr, rhs_ty, rhs_subst )): case Ok(( rhs_expr, rhs_ty, rhs_subst )):
# Unify the rhs type with the generated type of the rhs from earlier # Unify the rhs type with the generated type of the rhs from earlier
match subst_all_monotype(standin_ty, rhs_subst).unify(rhs_ty): match subst_all_monotype(standin_ty, rhs_subst).unify(rhs_ty):
case Ok(recursion_substs): case Ok(recursion_substs):
updated_ctx = context.subst_all(rhs_subst).with_mono(lhs, rhs_ty).subst_all(recursion_substs) updated_ctx = context.subst_all(rhs_subst).with_(lhs, context.generalize(rhs_ty)).subst_all(recursion_substs)
# Parse the body # Parse the body
match json_to_ir(body, updated_ctx): match json_to_ir(body, updated_ctx):

View file

@ -188,7 +188,8 @@ class Context:
return FSet(fv for ty in self.variable_types.values() for fv in ty.free_vars()) return FSet(fv for ty in self.variable_types.values() for fv in ty.free_vars())
def generalize(self, mt: MonoType) -> PolyType: def generalize(self, mt: MonoType) -> PolyType:
return PolyType(mt.free_vars(), mt) variable_mappings = [(VarTy(fv.id, 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'[i]), fv) for (i, fv) in enumerate(mt.free_vars())]
return PolyType([new_var for (new_var, _) in variable_mappings], subst_all_monotype(mt, variable_mappings))
def __contains__(self, name: str) -> bool: def __contains__(self, name: str) -> bool:
return name in self.variable_types return name in self.variable_types