diff --git a/genir.py b/genir.py index 9bd5102..1f302ab 100644 --- a/genir.py +++ b/genir.py @@ -2,15 +2,16 @@ from emis_funky_funktions import * from typing import Sequence, Mapping from pattern import lex_and_parse_pattern -from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole +from ir import Expression, MonoFunc, Application, Int, Variable, LetBinding, ReplHole from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern +from match_tree import MatchException from types_ import * from functools import reduce import json JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str' -SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable' +SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable | MatchException' @dataclass(frozen=True) class PatternParseProblem: @@ -170,8 +171,16 @@ def let_2_ir( def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]: if isinstance(j, Mapping): - return branches_2_ir(tuple(j.items()), type_ctx) <= (lambda ir_ty_subst: - (Function(ir_ty_subst[0]), *ir_ty_subst[1:])) + match branches_2_ir(tuple(j.items()), type_ctx): + case Ok((branches, ty, substitutions)): + match MonoFunc.from_match_function(branches): + case Ok(monofunc): + return Ok((monofunc, ty, substitutions)) + case Err(e): + return Err(e) + case Err(e): + return Err(e) + raise Exception('Unreachable') elif isinstance(j, str): match type_ctx.instantiate(j): case Some(j_type): diff --git a/ir.py b/ir.py index 809f555..2cb5719 100644 --- a/ir.py +++ b/ir.py @@ -8,8 +8,8 @@ from patterns import Pattern import types_ -Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch' -Value: TypeAlias = 'Function | Int | Builtin | ReplHole' +Expression: TypeAlias = 'MonoFunc | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch' +Value: TypeAlias = 'MonoFunc | Int | Builtin | ReplHole' @dataclass(frozen=True) class ReplHole: @@ -91,14 +91,15 @@ BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = ( ) @dataclass(frozen=True) -class Function: - forms: 'Sequence[Tuple[Pattern, Expression]]' +class MonoFunc: + arg: str + body: Expression def subst(self, expression: Expression, variable: str) -> Expression: - return Function([ - (p, e if p.binds(variable) else e.subst(expression, variable)) - for (p, e) in self.forms - ]) + if variable == self.arg: + return self + else: + return MonoFunc(self.arg, self.body.subst(expression, variable)) def is_value(self) -> bool: return True @@ -106,52 +107,42 @@ class Function: def step(self) -> Option[Expression]: return None - def eliminate(self, v: Expression) -> Result[Expression, MatchException]: - match_trees = tuple(pattern.match_tree(EMPTY_STRUCT_PATH, LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body))) for (pattern, body) in self.forms) - unified_match_tree = merge_all_trees(match_trees) - return compile_tree(unified_match_tree, v) - - def try_apply(self, v: Expression) -> Option[Expression]: - return hush(self.eliminate(v)) - - def codegen_inner(self) -> str: - return unwrap_r(self.eliminate(Variable('$'))).codegen() - def try_codegen_sp(self) -> Option[Tuple[str, str]]: - """ A special-case of codegen inner (see description) - - In certain cases, starting a function with a full match tree may be unnecessary. - Specifically, if there exists only one possible branch and that branch binds only - one value and that value is equal to the whole entire input, rather than assigning - that input to a new variable, we may simply use argument variable instead. - - This method returns the generated code for the inner branch in such a case. - Additionally, the second string returned represents the name of the variable which - ought to be bound as the argument. If the argument is unused, this will be "$". - """ - match self.forms: - case [(patt, expr)]: # A single possible branch + @staticmethod + def from_match_function(forms: 'Sequence[Tuple[Pattern, Expression]]') -> Result[Expression, MatchException]: + # In certain cases, starting a function with a full match tree may be unnecessary. + # Specifically, if there exists only one possible branch and that branch binds only + # one value and that value is equal to the whole entire input, rather than assigning + # that input to a new variable, we may simply use argument variable instead. + match forms: + case [(patt, body)]: # A single possible branch match patt.bindings(): case []: # Binds nothing - return Some((expr.codegen(), '$')) + return Ok(MonoFunc('_', body)) case [(var, [])]: # Binds a single variable to the entire input - return Some((expr.codegen(), var)) - return None + return Ok(MonoFunc(var, body)) + + # If those special cases fail, we eliminate the pattern matching to produce a + # single body: + match_trees = tuple( # Construct a match tree for each possible branch + pattern.match_tree( + EMPTY_STRUCT_PATH, + LeafNode.from_value(bindings_to_lets(pattern.bindings(), Variable('$'), body)) + ) + for (pattern, body) in forms + ) + unified_match_tree = merge_all_trees(match_trees) # Unify all the trees + compiled_tree = compile_tree(unified_match_tree, Variable('$')) # Turn each tree into IR + return compiled_tree <= p(MonoFunc, '$') + + def try_apply(self, v: Expression) -> Option[Expression]: + return Some(self.body.subst(v, self.arg)) + def codegen(self) -> str: - match self.try_codegen_sp(): - case Some((codegen, var)): - return f'{var}=>{codegen}' - case None: - return '$=>' + self.codegen_inner() - raise Exception('Unreachable') + return f'({self.arg}=>{self.body.codegen()})' def codegen_named(self, name) -> str: - match self.try_codegen_sp(): - case Some((codegen, var)): - return f'function {name}({var}){{return {codegen}}}' - case None: - return f'function {name}($){{return {self.codegen_inner()}}}' - raise Exception('Unreachable') + return f'(function {name}({self.arg}){{return {self.body.codegen()}}})' def __repr__(self) -> str: - return '{ ' + ', '.join('"' + repr(repr(p))[1:-1] + '" : ' + repr(e) for (p, e) in self.forms) + ' }' + return f'{{{repr(self.arg)}: {repr(self.body)}}}' @dataclass class LetBinding: @@ -188,10 +179,10 @@ class LetBinding: ) def __repr__(self) -> str: - return f'( {repr(self.lhs)}, {repr(self.rhs)}, {repr(self.body)} )' + return f'( "{self.lhs}", {repr(self.rhs)}, {repr(self.body)} )' def codegen(self) -> str: - rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, Function) else self.rhs.codegen() + rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, MonoFunc) else self.rhs.codegen() return f'({self.lhs}=>{self.body.codegen()})({rhs_cg})' @dataclass @@ -217,7 +208,7 @@ class Application: case Some(arg_stepped): return Some(Application(self.first, arg_stepped)) case None: - assert isinstance(self.first, Function) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed" + assert isinstance(self.first, MonoFunc) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed" return self.first.try_apply(self.arg) raise Exception('Unreachable') @@ -225,7 +216,7 @@ class Application: return f'[ {repr(self.first)}, {repr(self.arg)} ]' def codegen(self) -> str: - if isinstance(self.first, Function | Builtin) and self.arg.is_value(): + if isinstance(self.first, MonoFunc | Builtin) and self.arg.is_value(): return unwrap_opt(self.first.try_apply(self.arg)).codegen() else: match self.first: