from emis_funky_funktions import * from typing import Sequence, Mapping from pattern import lex_and_parse_pattern from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern from types_ import * from functools import reduce import json JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str' SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable' @dataclass(frozen=True) class PatternParseProblem: pattern: str @dataclass(frozen=True) class PatternTypeMismatch: pattern: Pattern real_type: PolyType required_type: PolyType @dataclass(frozen=True) class UndefinedVariable: name: str @dataclass(frozen=True) class BranchTypesDiffer: branch_1_pattern: str branch_2_pattern: str branch_1_type: MonoType branch_2_type: MonoType def pattern_type_bindings(pattern: Pattern, ctx: Context) -> Result[Tuple[MonoType, Sequence[Tuple[str, MonoType]]], PatternTypeMismatch]: match pattern: case NamePattern(name): new_ty = ctx.new_type_var() return Ok((new_ty, ((name, new_ty),))) case IgnorePattern(): new_ty = ctx.new_type_var() return Ok((new_ty, tuple())) case IntPattern(): return Ok((IntTy(), tuple())) case SPattern(pred): match pattern_type_bindings(pred, ctx): case Ok((desc_ty, desc_bindings)): match desc_ty.unify(IntTy()): case Ok(_): return Ok((IntTy(), desc_bindings)) case Err(_): return Err(PatternTypeMismatch(pred, ctx.generalize(desc_ty), PolyType([], IntTy()))) case Err(_) as e: return e raise Exception('Unreachable') def branches_2_ir( branches: Sequence[Tuple[str, JsonType]], context: Context, ) -> Result[tuple[Sequence[tuple[Pattern, Expression]], MonoType, Sequence[Substitution]], SemanticError]: # TOdO Affirm argument type with pattern match branches: case [(raw_pattern, raw_expr), *rest_branches_raw]: # Parse the pattern match lex_and_parse_pattern(raw_pattern): case Ok(parsed_pattern): # Type the pattern match pattern_type_bindings(parsed_pattern, context): case Ok((pattern_accept_type, bindings)): # Parse the body context_with_pattern_bindings = context.with_many_mono(bindings) match json_to_ir(raw_expr, context.with_many_mono(bindings)): case Ok(( parsed_expr, branch_ty, branch_subs )): updated_ctx = context.subst_all(branch_subs) # Compute the type of this branch this_branch = FunctionTy(subst_all_monotype(pattern_accept_type, branch_subs), branch_ty) # Parse the rest of the branches match branches_2_ir(rest_branches_raw, updated_ctx): case Ok(( rest_branches, rest_ty, rest_substs )): # Unify this branch and the rest of the branches match subst_all_monotype(this_branch, rest_substs).unify(rest_ty): case Ok(branch_unif_subst): return Ok(( ((parsed_pattern, parsed_expr), *rest_branches), subst_all_monotype(rest_ty, branch_unif_subst), (*branch_subs, *rest_substs, *branch_unif_subst) )) case Err(unif_err): # This branch's type disagrees with the rest of the branches return Err(BranchTypesDiffer(raw_pattern, rest_branches_raw[0][0], this_branch, rest_ty)) case Err(e): # Problem parsing one of the remaining branches return Err(e) case Err(_) as e: # Problem parsing expression return e case Err(_) as e: # Pattern type mismatch return e case Err(_): # Problem parsing pattern return Err(PatternParseProblem(raw_pattern)) pass case []: return Ok(( tuple(), context.new_type_var(), tuple() )) raise Exception('Unreachable, I hope') #god why can't mypy check this def seq_nonmt_2_ir( first_expr: Expression, first_ty: MonoType, right: Sequence[JsonType], context: Context ) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]: match right: case [arg, *rest]: match json_to_ir(arg, context): case Ok(( arg_expr, arg_ty, arg_substs )): ret_ty = context.new_type_var() match first_ty.unify(FunctionTy(arg_ty, ret_ty)): case Ok(unification_substs): updated_ctx = context.subst_all(arg_substs).subst_all(unification_substs) return seq_nonmt_2_ir( Application(first_expr, arg_expr), subst_all_monotype(ret_ty, unification_substs), rest, updated_ctx ) <= (lambda expr__ty__subst: ( expr__ty__subst[0] , expr__ty__subst[1] , (*arg_substs, *unification_substs, *expr__ty__subst[2]) ) ) case Err(_) as e: return e case Err(_) as e: return e case []: return Ok(( first_expr, first_ty, tuple() )) raise Exception('Unreachable') def let_2_ir( lhs: str, rhs: JsonType, body: Sequence[JsonType], context: Context, ) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]: # Parse the rhs standin_ty = context.new_type_var() context_for_rhs = context.with_mono(lhs, standin_ty) if isinstance(rhs, Mapping) else context match json_to_ir(rhs, context_for_rhs): case Ok(( rhs_expr, rhs_ty, rhs_subst )): # Unify the rhs type with the generated type of the rhs from earlier match subst_all_monotype(standin_ty, rhs_subst).unify(rhs_ty): case Ok(recursion_substs): updated_ctx = context.subst_all(rhs_subst).with_(lhs, context.generalize(rhs_ty)).subst_all(recursion_substs) # Parse the body match json_to_ir(body, updated_ctx): case Ok(( body_expr, body_ty, body_substs )): return Ok(( LetBinding(lhs, rhs_expr, body_expr), body_ty, (*rhs_subst, *recursion_substs, *body_substs) )) case Err(_) as e: return e case Err(_) as e: return e case Err(_) as e: return e raise Exception('Unreachable') def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]: if isinstance(j, Mapping): return branches_2_ir(tuple(j.items()), type_ctx) <= (lambda ir_ty_subst: (Function(ir_ty_subst[0]), *ir_ty_subst[1:])) elif isinstance(j, str): match type_ctx.instantiate(j): case Some(j_type): return Ok(( Variable(j), j_type, tuple() )) case None: return Err(UndefinedVariable(j)) raise Exception('Unreachable') elif isinstance(j, Sequence): match j: case [first, *rest]: match json_to_ir(first, type_ctx): case Ok(( fst, fst_ty, fst_substs)): # Application updated_ctx = type_ctx.subst_all(fst_substs) return seq_nonmt_2_ir(fst, fst_ty, rest, updated_ctx) <= (lambda exp__ty__subs: (exp__ty__subs[0], exp__ty__subs[1], (*fst_substs, *exp__ty__subs[2])) ) case Err(UndefinedVariable(v)) if isinstance(first, str) and v == first: # Let or String match rest: case [rhs, *body]: # Let return let_2_ir(first, rhs, body, type_ctx) case []: # String raise Exception('TODO: Strings') case []: return Ok(( ReplHole(type_ctx), HoleTy(), tuple() )) raise Exception('Unreachable') else: return Ok(( Int(j), IntTy(), tuple() ))