205 lines
7.1 KiB
Python
205 lines
7.1 KiB
Python
from emis_funky_funktions import *
|
|
from typing import Sequence, Mapping
|
|
|
|
from pattern import lex_and_parse_pattern
|
|
from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole
|
|
from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
|
|
from types_ import *
|
|
|
|
from functools import reduce
|
|
import json
|
|
|
|
JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str'
|
|
SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable'
|
|
|
|
@dataclass(frozen=True)
class PatternParseProblem:
    """Error: a branch key could not be lexed/parsed as a pattern.

    Reported by branches_2_ir when lex_and_parse_pattern fails on a key.
    """

    # The offending pattern text, exactly as written in the JSON object key.
    pattern: str
|
|
|
|
@dataclass(frozen=True)
class PatternTypeMismatch:
    """Error: a sub-pattern's inferred type conflicts with what its enclosing
    pattern requires (e.g. the predecessor inside ``S`` must be an Int)."""

    # The sub-pattern whose type could not be reconciled.
    pattern: Pattern
    # The (generalized) type actually inferred for ``pattern``.
    real_type: PolyType
    # The type the surrounding pattern demanded.
    required_type: PolyType
|
|
|
|
@dataclass(frozen=True)
class UndefinedVariable:
    """Error: a string in variable position has no binding in the typing context.

    json_to_ir also inspects this error to recognise ``[name, rhs, *body]``
    let bindings whose head is a not-yet-bound name.
    """

    # The unbound identifier.
    name: str
|
|
|
|
@dataclass(frozen=True)
class BranchTypesDiffer:
    """Error: branches of one function literal have incompatible types.

    Reported by branches_2_ir when a branch's ``pattern -> body`` type does
    not unify with the combined type of the branches that follow it.
    """

    # Raw pattern text of the branch whose type disagreed.
    branch_1_pattern: str
    # Raw pattern text of the first of the following branches.
    branch_2_pattern: str
    # Function type inferred for the disagreeing branch.
    branch_1_type: MonoType
    # Unified function type of the following branches.
    branch_2_type: MonoType
|
|
|
|
def pattern_type_bindings(pattern: Pattern, ctx: Context) -> Result[Tuple[MonoType, Sequence[Tuple[str, MonoType]]], PatternTypeMismatch]:
|
|
match pattern:
|
|
case NamePattern(name):
|
|
new_ty = ctx.new_type_var()
|
|
return Ok((new_ty, ((name, new_ty),)))
|
|
case IgnorePattern():
|
|
new_ty = ctx.new_type_var()
|
|
return Ok((new_ty, tuple()))
|
|
case IntPattern():
|
|
return Ok((IntTy(), tuple()))
|
|
case SPattern(pred):
|
|
match pattern_type_bindings(pred, ctx):
|
|
case Ok((desc_ty, desc_bindings)):
|
|
match desc_ty.unify(IntTy()):
|
|
case Ok(_):
|
|
return Ok((IntTy(), desc_bindings))
|
|
case Err(_):
|
|
return Err(PatternTypeMismatch(pred, ctx.generalize(desc_ty), PolyType([], IntTy())))
|
|
case Err(_) as e:
|
|
return e
|
|
raise Exception('Unreachable')
|
|
|
|
def branches_2_ir(
|
|
branches: Sequence[Tuple[str, JsonType]],
|
|
context: Context,
|
|
) -> Result[tuple[Sequence[tuple[Pattern, Expression]], MonoType, Sequence[Substitution]], SemanticError]:
|
|
# TOdO Affirm argument type with pattern
|
|
match branches:
|
|
case [(raw_pattern, raw_expr), *rest_branches_raw]:
|
|
# Parse the pattern
|
|
match lex_and_parse_pattern(raw_pattern):
|
|
case Ok(parsed_pattern):
|
|
|
|
# Type the pattern
|
|
match pattern_type_bindings(parsed_pattern, context):
|
|
case Ok((pattern_accept_type, bindings)):
|
|
|
|
# Parse the body
|
|
context_with_pattern_bindings = context.with_many_mono(bindings)
|
|
match json_to_ir(raw_expr, context.with_many_mono(bindings)):
|
|
case Ok(( parsed_expr, branch_ty, branch_subs )):
|
|
updated_ctx = context.subst_all(branch_subs)
|
|
|
|
# Compute the type of this branch
|
|
this_branch = FunctionTy(subst_all_monotype(pattern_accept_type, branch_subs), branch_ty)
|
|
|
|
# Parse the rest of the branches
|
|
match branches_2_ir(rest_branches_raw, updated_ctx):
|
|
case Ok(( rest_branches, rest_ty, rest_substs )):
|
|
|
|
# Unify this branch and the rest of the branches
|
|
match subst_all_monotype(this_branch, rest_substs).unify(rest_ty):
|
|
case Ok(branch_unif_subst):
|
|
return Ok((
|
|
((parsed_pattern, parsed_expr), *rest_branches),
|
|
subst_all_monotype(rest_ty, branch_unif_subst),
|
|
(*branch_subs, *rest_substs, *branch_unif_subst)
|
|
))
|
|
case Err(unif_err): # This branch's type disagrees with the rest of the branches
|
|
return Err(BranchTypesDiffer(raw_pattern, rest_branches_raw[0][0], this_branch, rest_ty))
|
|
case Err(e): # Problem parsing one of the remaining branches
|
|
return Err(e)
|
|
case Err(_) as e: # Problem parsing expression
|
|
return e
|
|
case Err(_) as e: # Pattern type mismatch
|
|
return e
|
|
case Err(_): # Problem parsing pattern
|
|
return Err(PatternParseProblem(raw_pattern))
|
|
pass
|
|
case []:
|
|
return Ok(( tuple(), context.new_type_var(), tuple() ))
|
|
raise Exception('Unreachable, I hope') #god why can't mypy check this
|
|
|
|
def seq_nonmt_2_ir(
|
|
first_expr: Expression,
|
|
first_ty: MonoType,
|
|
right: Sequence[JsonType],
|
|
context: Context
|
|
) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
|
|
match right:
|
|
case [arg, *rest]:
|
|
match json_to_ir(arg, context):
|
|
case Ok(( arg_expr, arg_ty, arg_substs )):
|
|
ret_ty = context.new_type_var()
|
|
match first_ty.unify(FunctionTy(arg_ty, ret_ty)):
|
|
case Ok(unification_substs):
|
|
updated_ctx = context.subst_all(arg_substs).subst_all(unification_substs)
|
|
return seq_nonmt_2_ir(
|
|
Application(first_expr, arg_expr),
|
|
subst_all_monotype(ret_ty, unification_substs),
|
|
rest,
|
|
updated_ctx
|
|
) <= (lambda expr__ty__subst:
|
|
( expr__ty__subst[0]
|
|
, expr__ty__subst[1]
|
|
, (*arg_substs, *unification_substs, *expr__ty__subst[2])
|
|
)
|
|
)
|
|
case Err(_) as e:
|
|
return e
|
|
case Err(_) as e:
|
|
return e
|
|
case []:
|
|
return Ok(( first_expr, first_ty, tuple() ))
|
|
raise Exception('Unreachable')
|
|
|
|
def let_2_ir(
|
|
lhs: str,
|
|
rhs: JsonType,
|
|
body: Sequence[JsonType],
|
|
context: Context,
|
|
) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
|
|
# Parse the rhs
|
|
standin_ty = context.new_type_var()
|
|
context_for_rhs = context.with_mono(lhs, standin_ty) if isinstance(rhs, Mapping) else context
|
|
match json_to_ir(rhs, context_for_rhs):
|
|
case Ok(( rhs_expr, rhs_ty, rhs_subst )):
|
|
|
|
# Unify the rhs type with the generated type of the rhs from earlier
|
|
match subst_all_monotype(standin_ty, rhs_subst).unify(rhs_ty):
|
|
case Ok(recursion_substs):
|
|
updated_ctx = context.subst_all(rhs_subst).with_(lhs, context.generalize(rhs_ty)).subst_all(recursion_substs)
|
|
|
|
# Parse the body
|
|
match json_to_ir(body, updated_ctx):
|
|
case Ok(( body_expr, body_ty, body_substs )):
|
|
return Ok(( LetBinding(lhs, rhs_expr, body_expr), body_ty, (*rhs_subst, *recursion_substs, *body_substs) ))
|
|
case Err(_) as e:
|
|
return e
|
|
case Err(_) as e:
|
|
return e
|
|
case Err(_) as e:
|
|
return e
|
|
raise Exception('Unreachable')
|
|
|
|
def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
|
|
if isinstance(j, Mapping):
|
|
return branches_2_ir(tuple(j.items()), type_ctx) <= (lambda ir_ty_subst:
|
|
(Function(ir_ty_subst[0]), *ir_ty_subst[1:]))
|
|
elif isinstance(j, str):
|
|
match type_ctx.instantiate(j):
|
|
case Some(j_type):
|
|
return Ok(( Variable(j), j_type, tuple() ))
|
|
case None:
|
|
return Err(UndefinedVariable(j))
|
|
raise Exception('Unreachable')
|
|
elif isinstance(j, Sequence):
|
|
match j:
|
|
case [first, *rest]:
|
|
match json_to_ir(first, type_ctx):
|
|
case Ok(( fst, fst_ty, fst_substs)):
|
|
# Application
|
|
updated_ctx = type_ctx.subst_all(fst_substs)
|
|
return seq_nonmt_2_ir(fst, fst_ty, rest, updated_ctx) <= (lambda exp__ty__subs:
|
|
(exp__ty__subs[0], exp__ty__subs[1], (*fst_substs, *exp__ty__subs[2]))
|
|
)
|
|
case Err(UndefinedVariable(v)) if isinstance(first, str) and v == first:
|
|
# Let or String
|
|
match rest:
|
|
case [rhs, *body]:
|
|
# Let
|
|
return let_2_ir(first, rhs, body, type_ctx)
|
|
case []:
|
|
# String
|
|
raise Exception('TODO: Strings')
|
|
case []:
|
|
return Ok(( ReplHole(type_ctx), HoleTy(), tuple() ))
|
|
raise Exception('Unreachable')
|
|
else:
|
|
return Ok(( Int(j), IntTy(), tuple() )) |