From 6f89dfa1c51f247d552c484348afa20f25783bb4 Mon Sep 17 00:00:00 2001 From: Emi Simpson Date: Sat, 16 Mar 2024 22:02:22 -0400 Subject: [PATCH] Optimization: Remove redundant let expressions --- compile.py | 5 +++- ir.py | 24 +++++++++++++++- opt.py | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 2 deletions(-) create mode 100644 opt.py diff --git a/compile.py b/compile.py index a3e7b59..5966939 100644 --- a/compile.py +++ b/compile.py @@ -5,6 +5,7 @@ from ir import BUILTIN_SUBSTITUTIONS, Expression, ReplHole, subst_all from genir import json_to_ir, PatternParseProblem, BranchTypesDiffer, UndefinedVariable from types_ import BUILTINS_CONTEXT, UnificationError from silly_thing import evaluate +from opt import optimize_to_fixpoint, all_optimizations import json from dataclasses import dataclass @@ -16,7 +17,9 @@ def main(): case [_, file]: # TODO handle this expr, ty, substs = unwrap_r(json_to_ir(json.loads(open(sys.argv[1]).read()), BUILTINS_CONTEXT)) - result = evaluate(subst_all(BUILTIN_SUBSTITUTIONS, expr)) + expr_with_builtins = subst_all(BUILTIN_SUBSTITUTIONS, expr) + expr_with_optimization = optimize_to_fixpoint(all_optimizations, expr_with_builtins) + result = evaluate(expr_with_optimization) if isinstance(result, ReplHole): print(result.render()) else: diff --git a/ir.py b/ir.py index 2cb5719..48797a9 100644 --- a/ir.py +++ b/ir.py @@ -365,4 +365,26 @@ def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> E return body case [(var, replacement), *rest]: return subst_all(rest, body.subst(replacement, var)) - raise Exception('Unreachable') \ No newline at end of file + raise Exception('Unreachable') + +def count_uses(variable: str, expression: Expression) -> int: + match expression: + case MonoFunc(arg, body): + return 0 if arg == variable else count_uses(variable, body) + case Application(first, arg): + return count_uses(variable, first) + count_uses(variable, arg) + case Int(_): + return 0 + case Variable(name): + return 1 if name == variable else 0 + case Builtin(_, _): + return 0 + case LetBinding(lhs, rhs, body): + return count_uses(variable, rhs) + count_uses(variable, body) + case ReplHole(_, _): + return 0 + case Switch(branches, fallback, switching_on): + return ( + count_uses(variable, switching_on) + + count_uses(variable, fallback) + + sum(count_uses(variable, branch) for branch in branches.values())) \ No newline at end of file diff --git a/opt.py b/opt.py new file mode 100644 index 0000000..f727616 --- /dev/null +++ b/opt.py @@ -0,0 +1,80 @@ +from emis_funky_funktions import * + +from ir import Expression, LetBinding, count_uses, MonoFunc, Application, LetBinding, Switch, Int, Variable, Builtin, ReplHole + +from typing import TypeAlias, Collection +from functools import reduce + +Optimization: TypeAlias = Callable[[Expression], Option[Expression]] + +def eliminate_single_let(expr: Expression) -> Option[Expression]: + match expr: + case LetBinding(lhs, rhs, body): + if count_uses(lhs, body) <= 1: + # RHS is used at most once i nthe body + if count_uses(lhs, rhs) > 0: + # RHS is recursive + if not isinstance(body, Variable): + # But can still be pushed down + return Some(body.subst(LetBinding(lhs, rhs, Variable(lhs)), lhs)) + else: + # And is already maximally pushed down + return None + else: + # RHS is not recursive + return Some(body.subst(rhs, lhs)) + else: + # RHS is used multiple times in the body + return None + case _: + return None + +def apply_opts(optimizations: Collection[Optimization], expression: Expression) -> Tuple[Expression, int]: + count: int + optimized_expr: Expression + match expression: + case MonoFunc(arg, body): + (optimized_body, count) = apply_opts(optimizations, body) + optimized_expr = MonoFunc(arg, optimized_body) + case Application(first, arg): + (optimized_first, count_first) = apply_opts(optimizations, first) + (optimized_arg , count_arg ) = apply_opts(optimizations, arg) + optimized_expr = Application(first, arg) + count = count_first + count_arg + case LetBinding(lhs, rhs, body): + (optimized_rhs , count_rhs ) = apply_opts(optimizations, rhs) + (optimized_body , count_body ) = apply_opts(optimizations, body) + optimized_expr = LetBinding(lhs, optimized_rhs, optimized_body) + count = count_rhs + count_body + case Switch(branches, fallback, switching_on): + branch_optimizations = tuple( + (branch_num, apply_opts(optimizations, branch_body)) + for branch_num, branch_body in branches.items()) + optimized_fallback, count_fallback = apply_opts(optimizations, fallback) + optimized_switching_on, count_switching_on = apply_opts(optimizations, switching_on) + count = sum(count_branch for _, (_, count_branch) in branch_optimizations) + count_fallback + count_switching_on + optimized_branches = {branch_num: optimized_branch for branch_num, (optimized_branch, _) in branch_optimizations} + optimized_expr = Switch(optimized_branches, optimized_fallback, optimized_switching_on) + case Int(_) | Variable(_) | Builtin(_, _) | ReplHole(_, _) as optimized_expr: + count = 0 + + def fold_optimizations(acc: Tuple[Expression, int], optimization: Optimization) -> Tuple[Expression, int]: + acc_expression, acc_count = acc + match optimization(acc_expression): + case Some(expr_post_opt): + return (expr_post_opt, acc_count + 1) + case None: + return (acc_expression, acc_count) + raise Exception('Unreachable') + + return reduce(fold_optimizations, optimizations, (optimized_expr, 0)) + +def optimize_to_fixpoint(optimizations: Collection[Optimization], expression: Expression) -> Expression: + match apply_opts(optimizations, expression): + case (optimal_expression, 0): + return optimal_expression + case (optimized_expression, _): + return optimize_to_fixpoint(optimizations, optimized_expression) + raise Exception('Unreachable') + +all_optimizations: Sequence[Optimization] = (eliminate_single_let,) \ No newline at end of file