Perform match elimination at IR generation rather than codegen

This commit is contained in:
Emi Simpson 2024-03-16 19:42:09 -04:00
parent e6d8933ccf
commit 719de87ea9
Signed by: Emi
GPG Key ID: A12F2C2FFDC3D847
2 changed files with 56 additions and 56 deletions

View File

@ -2,15 +2,16 @@ from emis_funky_funktions import *
from typing import Sequence, Mapping from typing import Sequence, Mapping
from pattern import lex_and_parse_pattern from pattern import lex_and_parse_pattern
from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole from ir import Expression, MonoFunc, Application, Int, Variable, LetBinding, ReplHole
from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
from match_tree import MatchException
from types_ import * from types_ import *
from functools import reduce from functools import reduce
import json import json
JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str' JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str'
SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable' SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable | MatchException'
@dataclass(frozen=True) @dataclass(frozen=True)
class PatternParseProblem: class PatternParseProblem:
@ -170,8 +171,16 @@ def let_2_ir(
def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]: def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
if isinstance(j, Mapping): if isinstance(j, Mapping):
return branches_2_ir(tuple(j.items()), type_ctx) <= (lambda ir_ty_subst: match branches_2_ir(tuple(j.items()), type_ctx):
(Function(ir_ty_subst[0]), *ir_ty_subst[1:])) case Ok((branches, ty, substitutions)):
match MonoFunc.from_match_function(branches):
case Ok(monofunc):
return Ok((monofunc, ty, substitutions))
case Err(e):
return Err(e)
case Err(e):
return Err(e)
raise Exception('Unreachable')
elif isinstance(j, str): elif isinstance(j, str):
match type_ctx.instantiate(j): match type_ctx.instantiate(j):
case Some(j_type): case Some(j_type):

95
ir.py
View File

@ -8,8 +8,8 @@ from patterns import Pattern
import types_ import types_
Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch' Expression: TypeAlias = 'MonoFunc | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
Value: TypeAlias = 'Function | Int | Builtin | ReplHole' Value: TypeAlias = 'MonoFunc | Int | Builtin | ReplHole'
@dataclass(frozen=True) @dataclass(frozen=True)
class ReplHole: class ReplHole:
@ -91,14 +91,15 @@ BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = (
) )
@dataclass(frozen=True) @dataclass(frozen=True)
class Function: class MonoFunc:
forms: 'Sequence[Tuple[Pattern, Expression]]' arg: str
body: Expression
def subst(self, expression: Expression, variable: str) -> Expression: def subst(self, expression: Expression, variable: str) -> Expression:
return Function([ if variable == self.arg:
(p, e if p.binds(variable) else e.subst(expression, variable)) return self
for (p, e) in self.forms else:
]) return MonoFunc(self.arg, self.body.subst(expression, variable))
def is_value(self) -> bool: def is_value(self) -> bool:
return True return True
@ -106,52 +107,42 @@ class Function:
def step(self) -> Option[Expression]: def step(self) -> Option[Expression]:
return None return None
def eliminate(self, v: Expression) -> Result[Expression, MatchException]: @staticmethod
match_trees = tuple(pattern.match_tree(EMPTY_STRUCT_PATH, LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body))) for (pattern, body) in self.forms) def from_match_function(forms: 'Sequence[Tuple[Pattern, Expression]]') -> Result[Expression, MatchException]:
unified_match_tree = merge_all_trees(match_trees) # In certain cases, starting a function with a full match tree may be unnecessary.
return compile_tree(unified_match_tree, v) # Specifically, if there exists only one possible branch and that branch binds only
# one value and that value is equal to the whole entire input, rather than assigning
def try_apply(self, v: Expression) -> Option[Expression]: # that input to a new variable, we may simply use argument variable instead.
return hush(self.eliminate(v)) match forms:
case [(patt, body)]: # A single possible branch
def codegen_inner(self) -> str:
return unwrap_r(self.eliminate(Variable('$'))).codegen()
def try_codegen_sp(self) -> Option[Tuple[str, str]]:
""" A special-case of codegen inner (see description)
In certain cases, starting a function with a full match tree may be unnecessary.
Specifically, if there exists only one possible branch and that branch binds only
one value and that value is equal to the whole entire input, rather than assigning
that input to a new variable, we may simply use argument variable instead.
This method returns the generated code for the inner branch in such a case.
Additionally, the second string returned represents the name of the variable which
ought to be bound as the argument. If the argument is unused, this will be "$".
"""
match self.forms:
case [(patt, expr)]: # A single possible branch
match patt.bindings(): match patt.bindings():
case []: # Binds nothing case []: # Binds nothing
return Some((expr.codegen(), '$')) return Ok(MonoFunc('_', body))
case [(var, [])]: # Binds a single variable to the entire input case [(var, [])]: # Binds a single variable to the entire input
return Some((expr.codegen(), var)) return Ok(MonoFunc(var, body))
return None
# If those special cases fail, we eliminate the pattern matching to produce a
# single body:
match_trees = tuple( # Construct a match tree for each possible branch
pattern.match_tree(
EMPTY_STRUCT_PATH,
LeafNode.from_value(bindings_to_lets(pattern.bindings(), Variable('$'), body))
)
for (pattern, body) in forms
)
unified_match_tree = merge_all_trees(match_trees) # Unify all the trees
compiled_tree = compile_tree(unified_match_tree, Variable('$')) # Turn each tree into IR
return compiled_tree <= p(MonoFunc, '$')
def try_apply(self, v: Expression) -> Option[Expression]:
return Some(self.body.subst(v, self.arg))
def codegen(self) -> str: def codegen(self) -> str:
match self.try_codegen_sp(): return f'({self.arg}=>{self.body.codegen()})'
case Some((codegen, var)):
return f'{var}=>{codegen}'
case None:
return '$=>' + self.codegen_inner()
raise Exception('Unreachable')
def codegen_named(self, name) -> str: def codegen_named(self, name) -> str:
match self.try_codegen_sp(): return f'(function {name}({self.arg}){{return {self.body.codegen()}}})'
case Some((codegen, var)):
return f'function {name}({var}){{return {codegen}}}'
case None:
return f'function {name}($){{return {self.codegen_inner()}}}'
raise Exception('Unreachable')
def __repr__(self) -> str: def __repr__(self) -> str:
return '{ ' + ', '.join('"' + repr(repr(p))[1:-1] + '" : ' + repr(e) for (p, e) in self.forms) + ' }' return f'{{{repr(self.arg)}: {repr(self.body)}}}'
@dataclass @dataclass
class LetBinding: class LetBinding:
@ -188,10 +179,10 @@ class LetBinding:
) )
def __repr__(self) -> str: def __repr__(self) -> str:
return f'( {repr(self.lhs)}, {repr(self.rhs)}, {repr(self.body)} )' return f'( "{self.lhs}", {repr(self.rhs)}, {repr(self.body)} )'
def codegen(self) -> str: def codegen(self) -> str:
rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, Function) else self.rhs.codegen() rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, MonoFunc) else self.rhs.codegen()
return f'({self.lhs}=>{self.body.codegen()})({rhs_cg})' return f'({self.lhs}=>{self.body.codegen()})({rhs_cg})'
@dataclass @dataclass
@ -217,7 +208,7 @@ class Application:
case Some(arg_stepped): case Some(arg_stepped):
return Some(Application(self.first, arg_stepped)) return Some(Application(self.first, arg_stepped))
case None: case None:
assert isinstance(self.first, Function) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed" assert isinstance(self.first, MonoFunc) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
return self.first.try_apply(self.arg) return self.first.try_apply(self.arg)
raise Exception('Unreachable') raise Exception('Unreachable')
@ -225,7 +216,7 @@ class Application:
return f'[ {repr(self.first)}, {repr(self.arg)} ]' return f'[ {repr(self.first)}, {repr(self.arg)} ]'
def codegen(self) -> str: def codegen(self) -> str:
if isinstance(self.first, Function | Builtin) and self.arg.is_value(): if isinstance(self.first, MonoFunc | Builtin) and self.arg.is_value():
return unwrap_opt(self.first.try_apply(self.arg)).codegen() return unwrap_opt(self.first.try_apply(self.arg)).codegen()
else: else:
match self.first: match self.first: