2024-03-17 02:02:22 +00:00
|
|
|
from emis_funky_funktions import *
|
|
|
|
|
|
|
|
from ir import Expression, LetBinding, count_uses, MonoFunc, Application, LetBinding, Switch, Int, Variable, Builtin, ReplHole
|
|
|
|
|
|
|
|
from typing import TypeAlias, Collection
|
|
|
|
from functools import reduce
|
|
|
|
|
|
|
|
# An optimization inspects a single expression node and either returns
# `Some(rewritten_expression)` when it applies, or `None` when it does not.
Optimization: TypeAlias = Callable[[Expression], Option[Expression]]
|
|
|
|
|
|
|
|
def eliminate_single_let(expr: Expression) -> Option[Expression]:
|
|
|
|
match expr:
|
|
|
|
case LetBinding(lhs, rhs, body):
|
2024-03-17 14:19:50 +00:00
|
|
|
rhs_is_simple = isinstance(rhs, Int | Variable | Builtin)
|
|
|
|
if count_uses(lhs, body) <= 1 or rhs_is_simple:
|
|
|
|
# RHS is used at most once i nthe body or replication wouldnt be costly
|
2024-03-17 02:02:22 +00:00
|
|
|
if count_uses(lhs, rhs) > 0:
|
|
|
|
# RHS is recursive
|
|
|
|
if not isinstance(body, Variable):
|
|
|
|
# But can still be pushed down
|
|
|
|
return Some(body.subst(LetBinding(lhs, rhs, Variable(lhs)), lhs))
|
|
|
|
else:
|
|
|
|
# And is already maximally pushed down
|
|
|
|
return None
|
|
|
|
else:
|
|
|
|
# RHS is not recursive
|
|
|
|
return Some(body.subst(rhs, lhs))
|
|
|
|
else:
|
|
|
|
# RHS is used multiple times in the body
|
|
|
|
return None
|
|
|
|
case _:
|
|
|
|
return None
|
|
|
|
|
2024-03-17 13:09:04 +00:00
|
|
|
def collapse_constant_additions(expr: Expression) -> Option[Expression]:
|
|
|
|
match expr:
|
|
|
|
case Application(Builtin(Builtin.BB_PLUS_CONST(x)), Application(Builtin(Builtin.BB_PLUS_CONST(y)), val)):
|
|
|
|
return Some(Application(Builtin(Builtin.BB_PLUS_CONST(x+y)), val))
|
|
|
|
return None
|
|
|
|
|
2024-03-17 13:46:31 +00:00
|
|
|
def eliminate_identity_operations(expr: Expression) -> Option[Expression]:
|
|
|
|
match expr:
|
|
|
|
case Application(Builtin(Builtin.BB_PLUS_CONST(0)), val):
|
|
|
|
return Some(val)
|
|
|
|
case Application(MonoFunc(arg, Variable(bod_var)), val):
|
|
|
|
if arg == bod_var:
|
|
|
|
return Some(val)
|
|
|
|
return None
|
|
|
|
|
|
|
|
def identify_constant_additions(expr: Expression):
|
|
|
|
match expr:
|
|
|
|
case Application(Builtin(Builtin.BB_PLUS()), Int(x)):
|
|
|
|
return Some(Builtin(Builtin.BB_PLUS_CONST(x)))
|
|
|
|
case Application(Application(Builtin(Builtin.BB_PLUS()), expr1), Int(x)):
|
|
|
|
return Some(Application(Builtin(Builtin.BB_PLUS_CONST(x)), expr1))
|
|
|
|
return None
|
|
|
|
|
2024-03-17 02:02:22 +00:00
|
|
|
def apply_opts(optimizations: Collection[Optimization], expression: Expression) -> Tuple[Expression, int]:
|
|
|
|
count: int
|
|
|
|
optimized_expr: Expression
|
|
|
|
match expression:
|
|
|
|
case MonoFunc(arg, body):
|
|
|
|
(optimized_body, count) = apply_opts(optimizations, body)
|
|
|
|
optimized_expr = MonoFunc(arg, optimized_body)
|
|
|
|
case Application(first, arg):
|
|
|
|
(optimized_first, count_first) = apply_opts(optimizations, first)
|
|
|
|
(optimized_arg , count_arg ) = apply_opts(optimizations, arg)
|
2024-03-17 13:45:59 +00:00
|
|
|
optimized_expr = Application(optimized_first, optimized_arg)
|
2024-03-17 02:02:22 +00:00
|
|
|
count = count_first + count_arg
|
|
|
|
case LetBinding(lhs, rhs, body):
|
|
|
|
(optimized_rhs , count_rhs ) = apply_opts(optimizations, rhs)
|
|
|
|
(optimized_body , count_body ) = apply_opts(optimizations, body)
|
|
|
|
optimized_expr = LetBinding(lhs, optimized_rhs, optimized_body)
|
|
|
|
count = count_rhs + count_body
|
|
|
|
case Switch(branches, fallback, switching_on):
|
|
|
|
branch_optimizations = tuple(
|
|
|
|
(branch_num, apply_opts(optimizations, branch_body))
|
|
|
|
for branch_num, branch_body in branches.items())
|
|
|
|
optimized_fallback, count_fallback = apply_opts(optimizations, fallback)
|
|
|
|
optimized_switching_on, count_switching_on = apply_opts(optimizations, switching_on)
|
|
|
|
count = sum(count_branch for _, (_, count_branch) in branch_optimizations) + count_fallback + count_switching_on
|
|
|
|
optimized_branches = {branch_num: optimized_branch for branch_num, (optimized_branch, _) in branch_optimizations}
|
|
|
|
optimized_expr = Switch(optimized_branches, optimized_fallback, optimized_switching_on)
|
2024-03-17 13:45:59 +00:00
|
|
|
case ReplHole(type_bindings, val_bindings):
|
|
|
|
val_optimizations = tuple((name, *apply_opts(optimizations, val)) for name, val in val_bindings)
|
|
|
|
count = sum(count_val_binding for _, _, count_val_binding in val_optimizations)
|
|
|
|
optimized_val_bindings = tuple((name, optimized_val_binding) for name, optimized_val_binding, _ in val_optimizations)
|
|
|
|
optimized_expr = ReplHole(type_bindings, optimized_val_bindings)
|
|
|
|
case Int(_) | Variable(_) | Builtin(_, _) as optimized_expr:
|
2024-03-17 02:02:22 +00:00
|
|
|
count = 0
|
|
|
|
|
|
|
|
def fold_optimizations(acc: Tuple[Expression, int], optimization: Optimization) -> Tuple[Expression, int]:
|
|
|
|
acc_expression, acc_count = acc
|
|
|
|
match optimization(acc_expression):
|
|
|
|
case Some(expr_post_opt):
|
|
|
|
return (expr_post_opt, acc_count + 1)
|
|
|
|
case None:
|
|
|
|
return (acc_expression, acc_count)
|
|
|
|
raise Exception('Unreachable')
|
|
|
|
|
|
|
|
return reduce(fold_optimizations, optimizations, (optimized_expr, 0))
|
|
|
|
|
|
|
|
def optimize_to_fixpoint(optimizations: Collection[Optimization], expression: Expression) -> Expression:
|
|
|
|
match apply_opts(optimizations, expression):
|
|
|
|
case (optimal_expression, 0):
|
|
|
|
return optimal_expression
|
|
|
|
case (optimized_expression, _):
|
|
|
|
return optimize_to_fixpoint(optimizations, optimized_expression)
|
|
|
|
raise Exception('Unreachable')
|
|
|
|
|
2024-03-17 13:46:31 +00:00
|
|
|
# Every optimization defined in this module, bundled for convenience.  When
# passed to `apply_opts`, they are tried on each node in the order listed.
all_optimizations: Sequence[Optimization] = (
    eliminate_single_let,
    collapse_constant_additions,
    eliminate_identity_operations,
    identify_constant_additions,
)
|