from emis_funky_funktions import * from ir import Expression, LetBinding, count_uses, MonoFunc, Application, LetBinding, Switch, Int, Variable, Builtin, ReplHole from typing import TypeAlias, Collection from functools import reduce Optimization: TypeAlias = Callable[[Expression], Option[Expression]] def eliminate_single_let(expr: Expression) -> Option[Expression]: match expr: case LetBinding(lhs, rhs, body): rhs_is_simple = isinstance(rhs, Int | Variable | Builtin) if count_uses(lhs, body) <= 1 or rhs_is_simple: # RHS is used at most once i nthe body or replication wouldnt be costly if count_uses(lhs, rhs) > 0: # RHS is recursive if not isinstance(body, Variable): # But can still be pushed down return Some(body.subst(LetBinding(lhs, rhs, Variable(lhs)), lhs)) else: # And is already maximally pushed down return None else: # RHS is not recursive return Some(body.subst(rhs, lhs)) else: # RHS is used multiple times in the body return None case _: return None def collapse_constant_additions(expr: Expression) -> Option[Expression]: match expr: case Application(Builtin(Builtin.BB_PLUS_CONST(x)), Application(Builtin(Builtin.BB_PLUS_CONST(y)), val)): return Some(Application(Builtin(Builtin.BB_PLUS_CONST(x+y)), val)) return None def eliminate_identity_operations(expr: Expression) -> Option[Expression]: match expr: case Application(Builtin(Builtin.BB_PLUS_CONST(0)), val): return Some(val) case Application(MonoFunc(arg, Variable(bod_var)), val): if arg == bod_var: return Some(val) return None def identify_constant_additions(expr: Expression): match expr: case Application(Builtin(Builtin.BB_PLUS()), Int(x)): return Some(Builtin(Builtin.BB_PLUS_CONST(x))) case Application(Application(Builtin(Builtin.BB_PLUS()), expr1), Int(x)): return Some(Application(Builtin(Builtin.BB_PLUS_CONST(x)), expr1)) return None def apply_opts(optimizations: Collection[Optimization], expression: Expression) -> Tuple[Expression, int]: count: int optimized_expr: Expression match expression: case MonoFunc(arg, body): (optimized_body, count) = apply_opts(optimizations, body) optimized_expr = MonoFunc(arg, optimized_body) case Application(first, arg): (optimized_first, count_first) = apply_opts(optimizations, first) (optimized_arg , count_arg ) = apply_opts(optimizations, arg) optimized_expr = Application(optimized_first, optimized_arg) count = count_first + count_arg case LetBinding(lhs, rhs, body): (optimized_rhs , count_rhs ) = apply_opts(optimizations, rhs) (optimized_body , count_body ) = apply_opts(optimizations, body) optimized_expr = LetBinding(lhs, optimized_rhs, optimized_body) count = count_rhs + count_body case Switch(branches, fallback, switching_on): branch_optimizations = tuple( (branch_num, apply_opts(optimizations, branch_body)) for branch_num, branch_body in branches.items()) optimized_fallback, count_fallback = apply_opts(optimizations, fallback) optimized_switching_on, count_switching_on = apply_opts(optimizations, switching_on) count = sum(count_branch for _, (_, count_branch) in branch_optimizations) + count_fallback + count_switching_on optimized_branches = {branch_num: optimized_branch for branch_num, (optimized_branch, _) in branch_optimizations} optimized_expr = Switch(optimized_branches, optimized_fallback, optimized_switching_on) case ReplHole(type_bindings, val_bindings): val_optimizations = tuple((name, *apply_opts(optimizations, val)) for name, val in val_bindings) count = sum(count_val_binding for _, _, count_val_binding in val_optimizations) optimized_val_bindings = tuple((name, optimized_val_binding) for name, optimized_val_binding, _ in val_optimizations) optimized_expr = ReplHole(type_bindings, optimized_val_bindings) case Int(_) | Variable(_) | Builtin(_, _) as optimized_expr: count = 0 def fold_optimizations(acc: Tuple[Expression, int], optimization: Optimization) -> Tuple[Expression, int]: acc_expression, acc_count = acc match optimization(acc_expression): case Some(expr_post_opt): return (expr_post_opt, acc_count + 1) case None: return (acc_expression, acc_count) raise Exception('Unreachable') return reduce(fold_optimizations, optimizations, (optimized_expr, 0)) def optimize_to_fixpoint(optimizations: Collection[Optimization], expression: Expression) -> Expression: match apply_opts(optimizations, expression): case (optimal_expression, 0): return optimal_expression case (optimized_expression, _): return optimize_to_fixpoint(optimizations, optimized_expression) raise Exception('Unreachable') all_optimizations: Sequence[Optimization] = (eliminate_single_let, collapse_constant_additions, eliminate_identity_operations, identify_constant_additions)