Optimization: Remove redundant let expressions

This commit is contained in:
Emi Simpson 2024-03-16 22:02:22 -04:00
parent 719de87ea9
commit 6f89dfa1c5
Signed by: Emi
GPG key ID: A12F2C2FFDC3D847
3 changed files with 107 additions and 2 deletions

View file

@ -5,6 +5,7 @@ from ir import BUILTIN_SUBSTITUTIONS, Expression, ReplHole, subst_all
from genir import json_to_ir, PatternParseProblem, BranchTypesDiffer, UndefinedVariable
from types_ import BUILTINS_CONTEXT, UnificationError
from silly_thing import evaluate
from opt import optimize_to_fixpoint, all_optimizations
import json
from dataclasses import dataclass
@ -16,7 +17,9 @@ def main():
case [_, file]:
# TODO handle this
expr, ty, substs = unwrap_r(json_to_ir(json.loads(open(sys.argv[1]).read()), BUILTINS_CONTEXT))
result = evaluate(subst_all(BUILTIN_SUBSTITUTIONS, expr))
expr_with_builtins = subst_all(BUILTIN_SUBSTITUTIONS, expr)
expr_with_optimization = optimize_to_fixpoint(all_optimizations, expr_with_builtins)
result = evaluate(expr_with_optimization)
if isinstance(result, ReplHole):
print(result.render())
else:

24
ir.py
View file

@ -365,4 +365,26 @@ def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> E
return body
case [(var, replacement), *rest]:
return subst_all(rest, body.subst(replacement, var))
raise Exception('Unreachable')
raise Exception('Unreachable')
def count_uses(variable: str, expression: Expression) -> int:
match expression:
case MonoFunc(arg, body):
return 0 if arg == variable else count_uses(variable, body)
case Application(first, arg):
return count_uses(variable, first) + count_uses(variable, arg)
case Int(_):
return 0
case Variable(name):
return 1 if name == variable else 0
case Builtin(_, _):
return 0
case LetBinding(lhs, rhs, body):
return count_uses(variable, rhs) + count_uses(variable, body)
case ReplHole(_, _):
return 0
case Switch(branches, fallback, switching_on):
return (
count_uses(variable, switching_on) +
count_uses(variable, fallback) +
sum(count_uses(variable, branch) for branch in branches.values()))

80
opt.py Normal file
View file

@ -0,0 +1,80 @@
from emis_funky_funktions import *
from ir import Expression, LetBinding, count_uses, MonoFunc, Application, LetBinding, Switch, Int, Variable, Builtin, ReplHole
from typing import TypeAlias, Collection
from functools import reduce
Optimization: TypeAlias = Callable[[Expression], Option[Expression]]
def eliminate_single_let(expr: Expression) -> Option[Expression]:
match expr:
case LetBinding(lhs, rhs, body):
if count_uses(lhs, body) <= 1:
# RHS is used at most once i nthe body
if count_uses(lhs, rhs) > 0:
# RHS is recursive
if not isinstance(body, Variable):
# But can still be pushed down
return Some(body.subst(LetBinding(lhs, rhs, Variable(lhs)), lhs))
else:
# And is already maximally pushed down
return None
else:
# RHS is not recursive
return Some(body.subst(rhs, lhs))
else:
# RHS is used multiple times in the body
return None
case _:
return None
def apply_opts(optimizations: Collection[Optimization], expression: Expression) -> Tuple[Expression, int]:
count: int
optimized_expr: Expression
match expression:
case MonoFunc(arg, body):
(optimized_body, count) = apply_opts(optimizations, body)
optimized_expr = MonoFunc(arg, optimized_body)
case Application(first, arg):
(optimized_first, count_first) = apply_opts(optimizations, first)
(optimized_arg , count_arg ) = apply_opts(optimizations, arg)
optimized_expr = Application(first, arg)
count = count_first + count_arg
case LetBinding(lhs, rhs, body):
(optimized_rhs , count_rhs ) = apply_opts(optimizations, rhs)
(optimized_body , count_body ) = apply_opts(optimizations, body)
optimized_expr = LetBinding(lhs, optimized_rhs, optimized_body)
count = count_rhs + count_body
case Switch(branches, fallback, switching_on):
branch_optimizations = tuple(
(branch_num, apply_opts(optimizations, branch_body))
for branch_num, branch_body in branches.items())
optimized_fallback, count_fallback = apply_opts(optimizations, fallback)
optimized_switching_on, count_switching_on = apply_opts(optimizations, switching_on)
count = sum(count_branch for _, (_, count_branch) in branch_optimizations) + count_fallback + count_switching_on
optimized_branches = {branch_num: optimized_branch for branch_num, (optimized_branch, _) in branch_optimizations}
optimized_expr = Switch(optimized_branches, optimized_fallback, optimized_switching_on)
case Int(_) | Variable(_) | Builtin(_, _) | ReplHole(_, _) as optimized_expr:
count = 0
def fold_optimizations(acc: Tuple[Expression, int], optimization: Optimization) -> Tuple[Expression, int]:
acc_expression, acc_count = acc
match optimization(acc_expression):
case Some(expr_post_opt):
return (expr_post_opt, acc_count + 1)
case None:
return (acc_expression, acc_count)
raise Exception('Unreachable')
return reduce(fold_optimizations, optimizations, (optimized_expr, 0))
def optimize_to_fixpoint(optimizations: Collection[Optimization], expression: Expression) -> Expression:
match apply_opts(optimizations, expression):
case (optimal_expression, 0):
return optimal_expression
case (optimized_expression, _):
return optimize_to_fixpoint(optimizations, optimized_expression)
raise Exception('Unreachable')
all_optimizations: Sequence[Optimization] = (eliminate_single_let,)