Compare commits

...

6 Commits

4 changed files with 168 additions and 30 deletions

View File

@ -5,6 +5,7 @@ from ir import BUILTIN_SUBSTITUTIONS, Expression, ReplHole, subst_all
from genir import json_to_ir, PatternParseProblem, BranchTypesDiffer, UndefinedVariable
from types_ import BUILTINS_CONTEXT, UnificationError
from silly_thing import evaluate
from opt import optimize_to_fixpoint, all_optimizations
import json
from dataclasses import dataclass
@ -16,7 +17,9 @@ def main():
case [_, file]:
# TODO handle this
expr, ty, substs = unwrap_r(json_to_ir(json.loads(open(sys.argv[1]).read()), BUILTINS_CONTEXT))
result = evaluate(subst_all(BUILTIN_SUBSTITUTIONS, expr))
expr_with_builtins = subst_all(BUILTIN_SUBSTITUTIONS, expr)
expr_with_optimization = optimize_to_fixpoint(all_optimizations, expr_with_builtins)
result = evaluate(expr_with_optimization)
if isinstance(result, ReplHole):
print(result.render())
else:

View File

@ -2,15 +2,16 @@ from emis_funky_funktions import *
from typing import Sequence, Mapping
from pattern import lex_and_parse_pattern
from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole
from ir import Expression, MonoFunc, Application, Int, Variable, LetBinding, ReplHole
from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
from match_tree import MatchException
from types_ import *
from functools import reduce
import json
JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str'
SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable'
SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable | MatchException'
@dataclass(frozen=True)
class PatternParseProblem:
@ -170,8 +171,16 @@ def let_2_ir(
def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
if isinstance(j, Mapping):
return branches_2_ir(tuple(j.items()), type_ctx) <= (lambda ir_ty_subst:
(Function(ir_ty_subst[0]), *ir_ty_subst[1:]))
match branches_2_ir(tuple(j.items()), type_ctx):
case Ok((branches, ty, substitutions)):
match MonoFunc.from_match_function(branches):
case Ok(monofunc):
return Ok((monofunc, ty, substitutions))
case Err(e):
return Err(e)
case Err(e):
return Err(e)
raise Exception('Unreachable')
elif isinstance(j, str):
match type_ctx.instantiate(j):
case Some(j_type):

96
ir.py
View File

@ -8,8 +8,8 @@ from patterns import Pattern
import types_
Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
Value: TypeAlias = 'Function | Int | Builtin | ReplHole'
Expression: TypeAlias = 'MonoFunc | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
Value: TypeAlias = 'MonoFunc | Int | Builtin | ReplHole'
@dataclass(frozen=True)
class ReplHole:
@ -91,14 +91,15 @@ BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = (
)
@dataclass(frozen=True)
class Function:
forms: 'Sequence[Tuple[Pattern, Expression]]'
class MonoFunc:
arg: str
body: Expression
def subst(self, expression: Expression, variable: str) -> Expression:
return Function([
(p, e if p.binds(variable) else e.subst(expression, variable))
for (p, e) in self.forms
])
if variable == self.arg:
return self
else:
return MonoFunc(self.arg, self.body.subst(expression, variable))
def is_value(self) -> bool:
return True
@ -106,22 +107,42 @@ class Function:
def step(self) -> Option[Expression]:
return None
def eliminate(self, v: Expression) -> Result[Expression, MatchException]:
match_trees = tuple(pattern.match_tree(EMPTY_STRUCT_PATH, LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body))) for (pattern, body) in self.forms)
unified_match_tree = merge_all_trees(match_trees)
return compile_tree(unified_match_tree, v)
@staticmethod
def from_match_function(forms: 'Sequence[Tuple[Pattern, Expression]]') -> Result[Expression, MatchException]:
# In certain cases, starting a function with a full match tree may be unnecessary.
# Specifically, if there exists only one possible branch and that branch binds only
# one value and that value is equal to the whole entire input, rather than assigning
# that input to a new variable, we may simply use argument variable instead.
match forms:
case [(patt, body)]: # A single possible branch
match patt.bindings():
case []: # Binds nothing
return Ok(MonoFunc('_', body))
case [(var, [])]: # Binds a single variable to the entire input
return Ok(MonoFunc(var, body))
# If those special cases fail, we eliminate the pattern matching to produce a
# single body:
match_trees = tuple( # Construct a match tree for each possible branch
pattern.match_tree(
EMPTY_STRUCT_PATH,
LeafNode.from_value(bindings_to_lets(pattern.bindings(), Variable('$'), body))
)
for (pattern, body) in forms
)
unified_match_tree = merge_all_trees(match_trees) # Unify all the trees
compiled_tree = compile_tree(unified_match_tree, Variable('$')) # Turn each tree into IR
return compiled_tree <= p(MonoFunc, '$')
def try_apply(self, v: Expression) -> Option[Expression]:
return hush(self.eliminate(v))
return Some(self.body.subst(v, self.arg))
def codegen_inner(self) -> str:
return unwrap_r(self.eliminate(Variable('$'))).codegen()
def codegen(self) -> str:
return '$=>' + self.codegen_inner()
return f'({self.arg}=>{self.body.codegen()})'
def codegen_named(self, name) -> str:
return f'function {name}($){{return {self.codegen_inner()}}}'
return f'(function {name}({self.arg}){{return {self.body.codegen()}}})'
def __repr__(self) -> str:
return '{ ' + ', '.join('"' + repr(repr(p))[1:-1] + '" : ' + repr(e) for (p, e) in self.forms) + ' }'
return f'{{{repr(self.arg)}: {repr(self.body)}}}'
@dataclass
class LetBinding:
@ -158,11 +179,14 @@ class LetBinding:
)
def __repr__(self) -> str:
return f'( {repr(self.lhs)}, {repr(self.rhs)}, {repr(self.body)} )'
return f'( "{self.lhs}", {repr(self.rhs)}, {repr(self.body)} )'
def codegen(self) -> str:
rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, Function) else self.rhs.codegen()
return f'({self.lhs}={rhs_cg},{self.body.codegen()})'
rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, MonoFunc) else self.rhs.codegen()
if self.body == Variable(self.lhs):
return rhs_cg
else:
return f'({self.lhs}=>{self.body.codegen()})({rhs_cg})'
@dataclass
class Application:
@ -187,7 +211,7 @@ class Application:
case Some(arg_stepped):
return Some(Application(self.first, arg_stepped))
case None:
assert isinstance(self.first, Function) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
assert isinstance(self.first, MonoFunc) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
return self.first.try_apply(self.arg)
raise Exception('Unreachable')
@ -195,7 +219,7 @@ class Application:
return f'[ {repr(self.first)}, {repr(self.arg)} ]'
def codegen(self) -> str:
if isinstance(self.first, Function | Builtin) and self.arg.is_value():
if isinstance(self.first, MonoFunc | Builtin) and self.arg.is_value():
return unwrap_opt(self.first.try_apply(self.arg)).codegen()
else:
match self.first:
@ -262,7 +286,7 @@ class Switch:
def subst(self, expression: Expression, variable: str) -> Expression:
return Switch(
{i: e.subst(expression, variable) for i, e in self.branches.items()},
self.fallback,
self.fallback.subst(expression, variable),
self.switching_on.subst(expression, variable))
def is_value(self) -> bool:
@ -344,4 +368,26 @@ def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> E
return body
case [(var, replacement), *rest]:
return subst_all(rest, body.subst(replacement, var))
raise Exception('Unreachable')
raise Exception('Unreachable')
def count_uses(variable: str, expression: Expression) -> int:
match expression:
case MonoFunc(arg, body):
return 0 if arg == variable else count_uses(variable, body)
case Application(first, arg):
return count_uses(variable, first) + count_uses(variable, arg)
case Int(_):
return 0
case Variable(name):
return 1 if name == variable else 0
case Builtin(_, _):
return 0
case LetBinding(lhs, rhs, body):
return count_uses(variable, rhs) + count_uses(variable, body)
case ReplHole(_, _):
return 0
case Switch(branches, fallback, switching_on):
return (
count_uses(variable, switching_on) +
count_uses(variable, fallback) +
sum(count_uses(variable, branch) for branch in branches.values()))

80
opt.py Normal file
View File

@ -0,0 +1,80 @@
from emis_funky_funktions import *
from ir import Expression, LetBinding, count_uses, MonoFunc, Application, LetBinding, Switch, Int, Variable, Builtin, ReplHole
from typing import TypeAlias, Collection
from functools import reduce
# An optimization looks at a single expression node and returns Some(replacement)
# when it rewrites that node, or None when it does not apply.
Optimization: TypeAlias = Callable[[Expression], Option[Expression]]
def eliminate_single_let(expr: Expression) -> Option[Expression]:
match expr:
case LetBinding(lhs, rhs, body):
if count_uses(lhs, body) <= 1:
# RHS is used at most once i nthe body
if count_uses(lhs, rhs) > 0:
# RHS is recursive
if not isinstance(body, Variable):
# But can still be pushed down
return Some(body.subst(LetBinding(lhs, rhs, Variable(lhs)), lhs))
else:
# And is already maximally pushed down
return None
else:
# RHS is not recursive
return Some(body.subst(rhs, lhs))
else:
# RHS is used multiple times in the body
return None
case _:
return None
def apply_opts(optimizations: Collection[Optimization], expression: Expression) -> Tuple[Expression, int]:
count: int
optimized_expr: Expression
match expression:
case MonoFunc(arg, body):
(optimized_body, count) = apply_opts(optimizations, body)
optimized_expr = MonoFunc(arg, optimized_body)
case Application(first, arg):
(optimized_first, count_first) = apply_opts(optimizations, first)
(optimized_arg , count_arg ) = apply_opts(optimizations, arg)
optimized_expr = Application(first, arg)
count = count_first + count_arg
case LetBinding(lhs, rhs, body):
(optimized_rhs , count_rhs ) = apply_opts(optimizations, rhs)
(optimized_body , count_body ) = apply_opts(optimizations, body)
optimized_expr = LetBinding(lhs, optimized_rhs, optimized_body)
count = count_rhs + count_body
case Switch(branches, fallback, switching_on):
branch_optimizations = tuple(
(branch_num, apply_opts(optimizations, branch_body))
for branch_num, branch_body in branches.items())
optimized_fallback, count_fallback = apply_opts(optimizations, fallback)
optimized_switching_on, count_switching_on = apply_opts(optimizations, switching_on)
count = sum(count_branch for _, (_, count_branch) in branch_optimizations) + count_fallback + count_switching_on
optimized_branches = {branch_num: optimized_branch for branch_num, (optimized_branch, _) in branch_optimizations}
optimized_expr = Switch(optimized_branches, optimized_fallback, optimized_switching_on)
case Int(_) | Variable(_) | Builtin(_, _) | ReplHole(_, _) as optimized_expr:
count = 0
def fold_optimizations(acc: Tuple[Expression, int], optimization: Optimization) -> Tuple[Expression, int]:
acc_expression, acc_count = acc
match optimization(acc_expression):
case Some(expr_post_opt):
return (expr_post_opt, acc_count + 1)
case None:
return (acc_expression, acc_count)
raise Exception('Unreachable')
return reduce(fold_optimizations, optimizations, (optimized_expr, 0))
def optimize_to_fixpoint(optimizations: Collection[Optimization], expression: Expression) -> Expression:
    """
    Repeat optimization passes over `expression` until a pass reports that
    nothing fired, then return the resulting expression.
    """
    current = expression
    while True:
        current, fired = apply_opts(optimizations, current)
        if fired == 0:
            # The last pass made no changes: a fixpoint has been reached.
            return current
# Every optimization defined in this module, in the order in which they are
# tried at each node.
all_optimizations: Sequence[Optimization] = (eliminate_single_let,)