Perform match elimination at IR generation rather than codegen

This commit is contained in:
Emi Simpson 2024-03-16 19:42:09 -04:00
parent e6d8933ccf
commit 719de87ea9
Signed by: Emi
GPG Key ID: A12F2C2FFDC3D847
2 changed files with 56 additions and 56 deletions

View File

@ -2,15 +2,16 @@ from emis_funky_funktions import *
from typing import Sequence, Mapping
from pattern import lex_and_parse_pattern
from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole
from ir import Expression, MonoFunc, Application, Int, Variable, LetBinding, ReplHole
from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
from match_tree import MatchException
from types_ import *
from functools import reduce
import json
JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str'
SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable'
SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable | MatchException'
@dataclass(frozen=True)
class PatternParseProblem:
@ -170,8 +171,16 @@ def let_2_ir(
def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
if isinstance(j, Mapping):
return branches_2_ir(tuple(j.items()), type_ctx) <= (lambda ir_ty_subst:
(Function(ir_ty_subst[0]), *ir_ty_subst[1:]))
match branches_2_ir(tuple(j.items()), type_ctx):
case Ok((branches, ty, substitutions)):
match MonoFunc.from_match_function(branches):
case Ok(monofunc):
return Ok((monofunc, ty, substitutions))
case Err(e):
return Err(e)
case Err(e):
return Err(e)
raise Exception('Unreachable')
elif isinstance(j, str):
match type_ctx.instantiate(j):
case Some(j_type):

95
ir.py
View File

@ -8,8 +8,8 @@ from patterns import Pattern
import types_
Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
Value: TypeAlias = 'Function | Int | Builtin | ReplHole'
Expression: TypeAlias = 'MonoFunc | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
Value: TypeAlias = 'MonoFunc | Int | Builtin | ReplHole'
@dataclass(frozen=True)
class ReplHole:
@ -91,14 +91,15 @@ BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = (
)
@dataclass(frozen=True)
class Function:
forms: 'Sequence[Tuple[Pattern, Expression]]'
class MonoFunc:
arg: str
body: Expression
def subst(self, expression: Expression, variable: str) -> Expression:
return Function([
(p, e if p.binds(variable) else e.subst(expression, variable))
for (p, e) in self.forms
])
if variable == self.arg:
return self
else:
return MonoFunc(self.arg, self.body.subst(expression, variable))
def is_value(self) -> bool:
return True
@ -106,52 +107,42 @@ class Function:
def step(self) -> Option[Expression]:
return None
def eliminate(self, v: Expression) -> Result[Expression, MatchException]:
match_trees = tuple(pattern.match_tree(EMPTY_STRUCT_PATH, LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body))) for (pattern, body) in self.forms)
unified_match_tree = merge_all_trees(match_trees)
return compile_tree(unified_match_tree, v)
def try_apply(self, v: Expression) -> Option[Expression]:
return hush(self.eliminate(v))
def codegen_inner(self) -> str:
return unwrap_r(self.eliminate(Variable('$'))).codegen()
def try_codegen_sp(self) -> Option[Tuple[str, str]]:
""" A special-case of codegen inner (see description)
In certain cases, starting a function with a full match tree may be unnecessary.
Specifically, if there exists only one possible branch and that branch binds only
one value and that value is equal to the whole entire input, rather than assigning
that input to a new variable, we may simply use argument variable instead.
This method returns the generated code for the inner branch in such a case.
Additionally, the second string returned represents the name of the variable which
ought to be bound as the argument. If the argument is unused, this will be "$".
"""
match self.forms:
case [(patt, expr)]: # A single possible branch
@staticmethod
def from_match_function(forms: 'Sequence[Tuple[Pattern, Expression]]') -> Result[Expression, MatchException]:
# In certain cases, starting a function with a full match tree may be unnecessary.
# Specifically, if there exists only one possible branch and that branch binds only
# one value and that value is equal to the whole entire input, rather than assigning
# that input to a new variable, we may simply use argument variable instead.
match forms:
case [(patt, body)]: # A single possible branch
match patt.bindings():
case []: # Binds nothing
return Some((expr.codegen(), '$'))
return Ok(MonoFunc('_', body))
case [(var, [])]: # Binds a single variable to the entire input
return Some((expr.codegen(), var))
return None
return Ok(MonoFunc(var, body))
# If those special cases fail, we eliminate the pattern matching to produce a
# single body:
match_trees = tuple( # Construct a match tree for each possible branch
pattern.match_tree(
EMPTY_STRUCT_PATH,
LeafNode.from_value(bindings_to_lets(pattern.bindings(), Variable('$'), body))
)
for (pattern, body) in forms
)
unified_match_tree = merge_all_trees(match_trees) # Unify all the trees
compiled_tree = compile_tree(unified_match_tree, Variable('$')) # Turn each tree into IR
return compiled_tree <= p(MonoFunc, '$')
def try_apply(self, v: Expression) -> Option[Expression]:
return Some(self.body.subst(v, self.arg))
def codegen(self) -> str:
match self.try_codegen_sp():
case Some((codegen, var)):
return f'{var}=>{codegen}'
case None:
return '$=>' + self.codegen_inner()
raise Exception('Unreachable')
return f'({self.arg}=>{self.body.codegen()})'
def codegen_named(self, name) -> str:
match self.try_codegen_sp():
case Some((codegen, var)):
return f'function {name}({var}){{return {codegen}}}'
case None:
return f'function {name}($){{return {self.codegen_inner()}}}'
raise Exception('Unreachable')
return f'(function {name}({self.arg}){{return {self.body.codegen()}}})'
def __repr__(self) -> str:
return '{ ' + ', '.join('"' + repr(repr(p))[1:-1] + '" : ' + repr(e) for (p, e) in self.forms) + ' }'
return f'{{{repr(self.arg)}: {repr(self.body)}}}'
@dataclass
class LetBinding:
@ -188,10 +179,10 @@ class LetBinding:
)
def __repr__(self) -> str:
return f'( {repr(self.lhs)}, {repr(self.rhs)}, {repr(self.body)} )'
return f'( "{self.lhs}", {repr(self.rhs)}, {repr(self.body)} )'
def codegen(self) -> str:
rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, Function) else self.rhs.codegen()
rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, MonoFunc) else self.rhs.codegen()
return f'({self.lhs}=>{self.body.codegen()})({rhs_cg})'
@dataclass
@ -217,7 +208,7 @@ class Application:
case Some(arg_stepped):
return Some(Application(self.first, arg_stepped))
case None:
assert isinstance(self.first, Function) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
assert isinstance(self.first, MonoFunc) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
return self.first.try_apply(self.arg)
raise Exception('Unreachable')
@ -225,7 +216,7 @@ class Application:
return f'[ {repr(self.first)}, {repr(self.arg)} ]'
def codegen(self) -> str:
if isinstance(self.first, Function | Builtin) and self.arg.is_value():
if isinstance(self.first, MonoFunc | Builtin) and self.arg.is_value():
return unwrap_opt(self.first.try_apply(self.arg)).codegen()
else:
match self.first: