diff --git a/opt.py b/opt.py index 4c8c45f..050a130 100644 --- a/opt.py +++ b/opt.py @@ -45,7 +45,7 @@ def apply_opts(optimizations: Collection[Optimization], expression: Expression) case Application(first, arg): (optimized_first, count_first) = apply_opts(optimizations, first) (optimized_arg , count_arg ) = apply_opts(optimizations, arg) - optimized_expr = Application(first, arg) + optimized_expr = Application(optimized_first, optimized_arg) count = count_first + count_arg case LetBinding(lhs, rhs, body): (optimized_rhs , count_rhs ) = apply_opts(optimizations, rhs) @@ -61,7 +61,13 @@ def apply_opts(optimizations: Collection[Optimization], expression: Expression) count = sum(count_branch for _, (_, count_branch) in branch_optimizations) + count_fallback + count_switching_on optimized_branches = {branch_num: optimized_branch for branch_num, (optimized_branch, _) in branch_optimizations} optimized_expr = Switch(optimized_branches, optimized_fallback, optimized_switching_on) - case Int(_) | Variable(_) | Builtin(_, _) | ReplHole(_, _) as optimized_expr: + case ReplHole(type_bindings, val_bindings): + val_optimizations = tuple((name, *apply_opts(optimizations, val)) for name, val in val_bindings) + print('VAL_OPTS:', val_optimizations) + count = sum(count_val_binding for _, _, count_val_binding in val_optimizations) + optimized_val_bindings = tuple((name, optimized_val_binding) for name, optimized_val_binding, _ in val_optimizations) + optimized_expr = ReplHole(type_bindings, optimized_val_bindings) + case Int(_) | Variable(_) | Builtin(_, _) as optimized_expr: count = 0 def fold_optimizations(acc: Tuple[Expression, int], optimization: Optimization) -> Tuple[Expression, int]: