# JSON-Lang/genir.py
#
# (source-viewer metadata) 214 lines
# 7.4 KiB
# Python

from emis_funky_funktions import *
from typing import Sequence, Mapping
from pattern import lex_and_parse_pattern
from ir import Expression, MonoFunc, Application, Int, Variable, LetBinding, ReplHole
from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
from match_tree import MatchException
from types_ import *
from functools import reduce
import json
JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str'
SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable | MatchException'
@dataclass(frozen=True)
class PatternParseProblem:
    """A key of a JSON-Lang function object could not be parsed as a pattern.

    Attributes:
        pattern: the raw pattern text that failed to parse.
    """

    pattern: str
@dataclass(frozen=True)
class PatternTypeMismatch:
    """A (sub-)pattern cannot take the type its position requires.

    ``pattern_type_bindings`` reports this when the sub-pattern of an
    ``SPattern`` cannot be unified with Int.
    """

    pattern: Pattern          # the offending (sub-)pattern
    real_type: PolyType       # the type inferred for it, generalized
    required_type: PolyType   # the type it was required to have
@dataclass(frozen=True)
class UndefinedVariable:
    """A string term names a variable that the typing context does not know.

    Attributes:
        name: the unknown variable name.
    """

    name: str
@dataclass(frozen=True)
class BranchTypesDiffer:
    """Two branches of one JSON-Lang function object have incompatible types.

    Each branch is typed as ``pattern type -> body type``; this error records
    the two pattern keys involved and the types that failed to unify.
    """

    branch_1_pattern: str     # raw pattern text of the branch being added
    branch_2_pattern: str     # raw pattern text of the next branch in the object
    branch_1_type: MonoType   # function type inferred for the first branch
    branch_2_type: MonoType   # combined function type of the following branches
def pattern_type_bindings(pattern: Pattern, ctx: Context) -> Result[Tuple[MonoType, Sequence[Tuple[str, MonoType]]], PatternTypeMismatch]:
match pattern:
case NamePattern(name):
new_ty = ctx.new_type_var()
return Ok((new_ty, ((name, new_ty),)))
case IgnorePattern():
new_ty = ctx.new_type_var()
return Ok((new_ty, tuple()))
case IntPattern():
return Ok((IntTy(), tuple()))
case SPattern(pred):
match pattern_type_bindings(pred, ctx):
case Ok((desc_ty, desc_bindings)):
match desc_ty.unify(IntTy()):
case Ok(_):
return Ok((IntTy(), desc_bindings))
case Err(_):
return Err(PatternTypeMismatch(pred, ctx.generalize(desc_ty), PolyType([], IntTy())))
case Err(_) as e:
return e
raise Exception('Unreachable')
def branches_2_ir(
branches: Sequence[Tuple[str, JsonType]],
context: Context,
) -> Result[tuple[Sequence[tuple[Pattern, Expression]], MonoType, Sequence[Substitution]], SemanticError]:
# TOdO Affirm argument type with pattern
match branches:
case [(raw_pattern, raw_expr), *rest_branches_raw]:
# Parse the pattern
match lex_and_parse_pattern(raw_pattern):
case Ok(parsed_pattern):
# Type the pattern
match pattern_type_bindings(parsed_pattern, context):
case Ok((pattern_accept_type, bindings)):
# Parse the body
context_with_pattern_bindings = context.with_many_mono(bindings)
match json_to_ir(raw_expr, context.with_many_mono(bindings)):
case Ok(( parsed_expr, branch_ty, branch_subs )):
updated_ctx = context.subst_all(branch_subs)
# Compute the type of this branch
this_branch = FunctionTy(subst_all_monotype(pattern_accept_type, branch_subs), branch_ty)
# Parse the rest of the branches
match branches_2_ir(rest_branches_raw, updated_ctx):
case Ok(( rest_branches, rest_ty, rest_substs )):
# Unify this branch and the rest of the branches
match subst_all_monotype(this_branch, rest_substs).unify(rest_ty):
case Ok(branch_unif_subst):
return Ok((
((parsed_pattern, parsed_expr), *rest_branches),
subst_all_monotype(rest_ty, branch_unif_subst),
(*branch_subs, *rest_substs, *branch_unif_subst)
))
case Err(unif_err): # This branch's type disagrees with the rest of the branches
return Err(BranchTypesDiffer(raw_pattern, rest_branches_raw[0][0], this_branch, rest_ty))
case Err(e): # Problem parsing one of the remaining branches
return Err(e)
case Err(_) as e: # Problem parsing expression
return e
case Err(_) as e: # Pattern type mismatch
return e
case Err(_): # Problem parsing pattern
return Err(PatternParseProblem(raw_pattern))
pass
case []:
return Ok(( tuple(), context.new_type_var(), tuple() ))
raise Exception('Unreachable, I hope') #god why can't mypy check this
def seq_nonmt_2_ir(
first_expr: Expression,
first_ty: MonoType,
right: Sequence[JsonType],
context: Context
) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
match right:
case [arg, *rest]:
match json_to_ir(arg, context):
case Ok(( arg_expr, arg_ty, arg_substs )):
ret_ty = context.new_type_var()
match first_ty.unify(FunctionTy(arg_ty, ret_ty)):
case Ok(unification_substs):
updated_ctx = context.subst_all(arg_substs).subst_all(unification_substs)
return seq_nonmt_2_ir(
Application(first_expr, arg_expr),
subst_all_monotype(ret_ty, unification_substs),
rest,
updated_ctx
) <= (lambda expr__ty__subst:
( expr__ty__subst[0]
, expr__ty__subst[1]
, (*arg_substs, *unification_substs, *expr__ty__subst[2])
)
)
case Err(_) as e:
return e
case Err(_) as e:
return e
case []:
return Ok(( first_expr, first_ty, tuple() ))
raise Exception('Unreachable')
def let_2_ir(
lhs: str,
rhs: JsonType,
body: Sequence[JsonType],
context: Context,
) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
# Parse the rhs
standin_ty = context.new_type_var()
context_for_rhs = context.with_mono(lhs, standin_ty) if isinstance(rhs, Mapping) else context
match json_to_ir(rhs, context_for_rhs):
case Ok(( rhs_expr, rhs_ty, rhs_subst )):
# Unify the rhs type with the generated type of the rhs from earlier
match subst_all_monotype(standin_ty, rhs_subst).unify(rhs_ty):
case Ok(recursion_substs):
updated_ctx = context.subst_all(rhs_subst).with_(lhs, context.generalize(rhs_ty)).subst_all(recursion_substs)
# Parse the body
match json_to_ir(body, updated_ctx):
case Ok(( body_expr, body_ty, body_substs )):
return Ok(( LetBinding(lhs, rhs_expr, body_expr), body_ty, (*rhs_subst, *recursion_substs, *body_substs) ))
case Err(_) as e:
return e
case Err(_) as e:
return e
case Err(_) as e:
return e
raise Exception('Unreachable')
def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
if isinstance(j, Mapping):
match branches_2_ir(tuple(j.items()), type_ctx):
case Ok((branches, ty, substitutions)):
match MonoFunc.from_match_function(branches):
case Ok(monofunc):
return Ok((monofunc, ty, substitutions))
case Err(e):
return Err(e)
case Err(e):
return Err(e)
raise Exception('Unreachable')
elif isinstance(j, str):
match type_ctx.instantiate(j):
case Some(j_type):
return Ok(( Variable(j), j_type, tuple() ))
case None:
return Err(UndefinedVariable(j))
raise Exception('Unreachable')
elif isinstance(j, Sequence):
match j:
case [first, *rest]:
match json_to_ir(first, type_ctx):
case Ok(( fst, fst_ty, fst_substs)):
# Application
updated_ctx = type_ctx.subst_all(fst_substs)
return seq_nonmt_2_ir(fst, fst_ty, rest, updated_ctx) <= (lambda exp__ty__subs:
(exp__ty__subs[0], exp__ty__subs[1], (*fst_substs, *exp__ty__subs[2]))
)
case Err(UndefinedVariable(v)) if isinstance(first, str) and v == first:
# Let or String
match rest:
case [rhs, *body]:
# Let
return let_2_ir(first, rhs, body, type_ctx)
case []:
# String
raise Exception('TODO: Strings')
case []:
return Ok(( ReplHole(type_ctx), HoleTy(), tuple() ))
raise Exception('Unreachable')
else:
return Ok(( Int(j), IntTy(), tuple() ))