JSON-Lang/opt.py

116 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from emis_funky_funktions import *
from ir import Expression, LetBinding, count_uses, MonoFunc, Application, LetBinding, Switch, Int, Variable, Builtin, ReplHole
from typing import TypeAlias, Collection
from functools import reduce
# An optimization is a single rewrite rule over the IR: it is called with one expression and
# returns `Some(replacement)` when the rule applies to that expression, or `None` when it does not.
Optimization: TypeAlias = Callable[[Expression], Option[Expression]]
def eliminate_single_let(expr: Expression) -> Option[Expression]:
match expr:
case LetBinding(lhs, rhs, body):
if count_uses(lhs, body) <= 1:
# RHS is used at most once i nthe body
if count_uses(lhs, rhs) > 0:
# RHS is recursive
if not isinstance(body, Variable):
# But can still be pushed down
return Some(body.subst(LetBinding(lhs, rhs, Variable(lhs)), lhs))
else:
# And is already maximally pushed down
return None
else:
# RHS is not recursive
return Some(body.subst(rhs, lhs))
else:
# RHS is used multiple times in the body
return None
case _:
return None
def collapse_constant_additions(expr: Expression) -> Option[Expression]:
match expr:
case Application(Builtin(Builtin.BB_PLUS_CONST(x)), Application(Builtin(Builtin.BB_PLUS_CONST(y)), val)):
return Some(Application(Builtin(Builtin.BB_PLUS_CONST(x+y)), val))
return None
def eliminate_identity_operations(expr: Expression) -> Option[Expression]:
match expr:
case Application(Builtin(Builtin.BB_PLUS_CONST(0)), val):
print('BBBBBBBB')
return Some(val)
case Application(MonoFunc(arg, Variable(bod_var)), val):
if arg == bod_var:
print('CCCCCCCCCCCCCcc')
return Some(val)
return None
def identify_constant_additions(expr: Expression):
match expr:
case Application(Builtin(Builtin.BB_PLUS()), Int(x)):
print('AAAααααα')
return Some(Builtin(Builtin.BB_PLUS_CONST(x)))
case Application(Application(Builtin(Builtin.BB_PLUS()), expr1), Int(x)):
print('DDDDDDDDDDD')
return Some(Application(Builtin(Builtin.BB_PLUS_CONST(x)), expr1))
return None
def apply_opts(optimizations: Collection[Optimization], expression: Expression) -> Tuple[Expression, int]:
count: int
optimized_expr: Expression
match expression:
case MonoFunc(arg, body):
(optimized_body, count) = apply_opts(optimizations, body)
optimized_expr = MonoFunc(arg, optimized_body)
case Application(first, arg):
(optimized_first, count_first) = apply_opts(optimizations, first)
(optimized_arg , count_arg ) = apply_opts(optimizations, arg)
optimized_expr = Application(optimized_first, optimized_arg)
count = count_first + count_arg
case LetBinding(lhs, rhs, body):
(optimized_rhs , count_rhs ) = apply_opts(optimizations, rhs)
(optimized_body , count_body ) = apply_opts(optimizations, body)
optimized_expr = LetBinding(lhs, optimized_rhs, optimized_body)
count = count_rhs + count_body
case Switch(branches, fallback, switching_on):
branch_optimizations = tuple(
(branch_num, apply_opts(optimizations, branch_body))
for branch_num, branch_body in branches.items())
optimized_fallback, count_fallback = apply_opts(optimizations, fallback)
optimized_switching_on, count_switching_on = apply_opts(optimizations, switching_on)
count = sum(count_branch for _, (_, count_branch) in branch_optimizations) + count_fallback + count_switching_on
optimized_branches = {branch_num: optimized_branch for branch_num, (optimized_branch, _) in branch_optimizations}
optimized_expr = Switch(optimized_branches, optimized_fallback, optimized_switching_on)
case ReplHole(type_bindings, val_bindings):
val_optimizations = tuple((name, *apply_opts(optimizations, val)) for name, val in val_bindings)
print('VAL_OPTS:', val_optimizations)
count = sum(count_val_binding for _, _, count_val_binding in val_optimizations)
optimized_val_bindings = tuple((name, optimized_val_binding) for name, optimized_val_binding, _ in val_optimizations)
optimized_expr = ReplHole(type_bindings, optimized_val_bindings)
case Int(_) | Variable(_) | Builtin(_, _) as optimized_expr:
count = 0
def fold_optimizations(acc: Tuple[Expression, int], optimization: Optimization) -> Tuple[Expression, int]:
acc_expression, acc_count = acc
match optimization(acc_expression):
case Some(expr_post_opt):
return (expr_post_opt, acc_count + 1)
case None:
return (acc_expression, acc_count)
raise Exception('Unreachable')
return reduce(fold_optimizations, optimizations, (optimized_expr, 0))
def optimize_to_fixpoint(optimizations: Collection[Optimization], expression: Expression) -> Expression:
match apply_opts(optimizations, expression):
case (optimal_expression, 0):
return optimal_expression
case (optimized_expression, _):
return optimize_to_fixpoint(optimizations, optimized_expression)
raise Exception('Unreachable')
# Every optimization pass defined in this module. `apply_opts` tries the passes it is given
# in iteration order, so the order of this tuple matters when it is passed there.
all_optimizations: Sequence[Optimization] = (
    eliminate_single_let,
    collapse_constant_additions,
    eliminate_identity_operations,
    identify_constant_additions,
)