Compare commits
4 Commits
b0ccfc6309
...
864d98638e
Author | SHA1 | Date |
---|---|---|
Emi Simpson | 864d98638e | |
Emi Simpson | 4f39a88f47 | |
Emi Simpson | a2b68cc73e | |
Emi Simpson | 37c809f5c0 |
65
ir.py
65
ir.py
|
@ -38,11 +38,30 @@ class ReplHole:
|
||||||
if var_name not in types_.BUILTINS_CONTEXT
|
if var_name not in types_.BUILTINS_CONTEXT
|
||||||
)
|
)
|
||||||
|
|
||||||
|
BuiltinBehavior: TypeAlias = 'Builtin.BB_PLUS_CONST | Builtin.BB_PLUS'
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Builtin:
|
class Builtin:
|
||||||
name: str
|
behavior: BuiltinBehavior
|
||||||
f: Callable[[Expression], Option[Expression]]
|
|
||||||
js: str
|
@dataclass(frozen=True)
|
||||||
|
class BB_PLUS_CONST:
|
||||||
|
amt: int
|
||||||
|
def name(self) -> str:
|
||||||
|
return f'{self.amt:+}'
|
||||||
|
def js(self) -> str:
|
||||||
|
return f'(x=>x{self.amt:+})'
|
||||||
|
def run(self, e: Expression) -> Option[Expression]:
|
||||||
|
return Some(Int(e.value + self.amt)) if isinstance(e, Int) else None
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class BB_PLUS:
|
||||||
|
def name(self) -> str:
|
||||||
|
return '+'
|
||||||
|
def js(self) -> str:
|
||||||
|
return '(x=>y=>x+y)'
|
||||||
|
def run(self, e: Expression) -> Option[Expression]:
|
||||||
|
return Some(Builtin(Builtin.BB_PLUS_CONST(e.value))) if isinstance(e, Int) else None
|
||||||
|
|
||||||
def subst(self, expression: Expression, variable: str) -> Expression:
|
def subst(self, expression: Expression, variable: str) -> Expression:
|
||||||
return self
|
return self
|
||||||
|
@ -54,36 +73,16 @@ class Builtin:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def try_apply(self, v: Expression) -> Option[Expression]:
|
def try_apply(self, v: Expression) -> Option[Expression]:
|
||||||
return self.f(v)
|
return self.behavior.run(v)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "'" + repr(self.name)[1:-1] + "'"
|
return "'" + repr(self.behavior.name())[1:-1] + "'"
|
||||||
|
|
||||||
def codegen(self) -> str:
|
def codegen(self) -> str:
|
||||||
return self.js
|
return self.behavior.js()
|
||||||
|
|
||||||
@cur2
|
PLUS: 'Callable[[], Builtin]' = lambda: Builtin(Builtin.BB_PLUS())
|
||||||
@staticmethod
|
S: 'Callable[[], Builtin]' = lambda: Builtin(Builtin.BB_PLUS_CONST(1))
|
||||||
def _PLUS_CONST(i: int, e: Expression) -> Option[Expression]:
|
|
||||||
match e:
|
|
||||||
case Int(v):
|
|
||||||
return Some(Int(i + v))
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _PLUS(e: Expression) -> Option[Expression]:
|
|
||||||
match e:
|
|
||||||
case Int(v):
|
|
||||||
return Some(Builtin(f'+{v}', Builtin._PLUS_CONST(v), f'(x => x + {v})'))
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def PLUS() -> 'Builtin':
|
|
||||||
return Builtin('+', Builtin._PLUS, '(x => y => x + y)')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def S() -> 'Builtin':
|
|
||||||
return Builtin('S', Builtin._PLUS_CONST(1), '(x => x + 1)')
|
|
||||||
|
|
||||||
BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = (
|
BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = (
|
||||||
('+', Builtin.PLUS()),
|
('+', Builtin.PLUS()),
|
||||||
|
@ -223,12 +222,10 @@ class Application:
|
||||||
return unwrap_opt(self.first.try_apply(self.arg)).codegen()
|
return unwrap_opt(self.first.try_apply(self.arg)).codegen()
|
||||||
else:
|
else:
|
||||||
match self.first:
|
match self.first:
|
||||||
case Application(Builtin('+', _, _), addend1):
|
case Application(Builtin(Builtin.BB_PLUS), addend1):
|
||||||
return f'({addend1.codegen()} + {self.arg.codegen()})'
|
return f'({addend1.codegen()} + {self.arg.codegen()})'
|
||||||
case Builtin('S', _, _):
|
case Builtin(Builtin.BB_PLUS_CONST(n)):
|
||||||
return f'(1+{self.arg.codegen()})'
|
return f'({self.arg.codegen()}{n:+})'
|
||||||
case Builtin('pred', _, _):
|
|
||||||
return f'({self.arg.codegen()}-1)'
|
|
||||||
return f'({self.first.codegen()})({self.arg.codegen()})'
|
return f'({self.first.codegen()})({self.arg.codegen()})'
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -340,7 +337,7 @@ def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Re
|
||||||
def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]:
|
def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]:
|
||||||
def access_location(part: int) -> Callable[[Expression], Expression]:
|
def access_location(part: int) -> Callable[[Expression], Expression]:
|
||||||
def remove(expr: Expression) -> Expression:
|
def remove(expr: Expression) -> Expression:
|
||||||
return Application(Builtin(f'pred', Builtin._PLUS_CONST(-1), f'$=>$-1'), expr)
|
return Application(Builtin(Builtin.BB_PLUS_CONST(-1)), expr)
|
||||||
def access_location_prime(expr: Expression) -> Expression:
|
def access_location_prime(expr: Expression) -> Expression:
|
||||||
if part < 1:
|
if part < 1:
|
||||||
return remove(expr)
|
return remove(expr)
|
||||||
|
|
42
opt.py
42
opt.py
|
@ -29,6 +29,33 @@ def eliminate_single_let(expr: Expression) -> Option[Expression]:
|
||||||
case _:
|
case _:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def collapse_constant_additions(expr: Expression) -> Option[Expression]:
|
||||||
|
match expr:
|
||||||
|
case Application(Builtin(Builtin.BB_PLUS_CONST(x)), Application(Builtin(Builtin.BB_PLUS_CONST(y)), val)):
|
||||||
|
return Some(Application(Builtin(Builtin.BB_PLUS_CONST(x+y)), val))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def eliminate_identity_operations(expr: Expression) -> Option[Expression]:
|
||||||
|
match expr:
|
||||||
|
case Application(Builtin(Builtin.BB_PLUS_CONST(0)), val):
|
||||||
|
print('BBBBBBBB')
|
||||||
|
return Some(val)
|
||||||
|
case Application(MonoFunc(arg, Variable(bod_var)), val):
|
||||||
|
if arg == bod_var:
|
||||||
|
print('CCCCCCCCCCCCCcc')
|
||||||
|
return Some(val)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def identify_constant_additions(expr: Expression):
|
||||||
|
match expr:
|
||||||
|
case Application(Builtin(Builtin.BB_PLUS()), Int(x)):
|
||||||
|
print('AAAααααα')
|
||||||
|
return Some(Builtin(Builtin.BB_PLUS_CONST(x)))
|
||||||
|
case Application(Application(Builtin(Builtin.BB_PLUS()), expr1), Int(x)):
|
||||||
|
print('DDDDDDDDDDD')
|
||||||
|
return Some(Application(Builtin(Builtin.BB_PLUS_CONST(x)), expr1))
|
||||||
|
return None
|
||||||
|
|
||||||
def apply_opts(optimizations: Collection[Optimization], expression: Expression) -> Tuple[Expression, int]:
|
def apply_opts(optimizations: Collection[Optimization], expression: Expression) -> Tuple[Expression, int]:
|
||||||
count: int
|
count: int
|
||||||
optimized_expr: Expression
|
optimized_expr: Expression
|
||||||
|
@ -39,7 +66,7 @@ def apply_opts(optimizations: Collection[Optimization], expression: Expression)
|
||||||
case Application(first, arg):
|
case Application(first, arg):
|
||||||
(optimized_first, count_first) = apply_opts(optimizations, first)
|
(optimized_first, count_first) = apply_opts(optimizations, first)
|
||||||
(optimized_arg , count_arg ) = apply_opts(optimizations, arg)
|
(optimized_arg , count_arg ) = apply_opts(optimizations, arg)
|
||||||
optimized_expr = Application(first, arg)
|
optimized_expr = Application(optimized_first, optimized_arg)
|
||||||
count = count_first + count_arg
|
count = count_first + count_arg
|
||||||
case LetBinding(lhs, rhs, body):
|
case LetBinding(lhs, rhs, body):
|
||||||
(optimized_rhs , count_rhs ) = apply_opts(optimizations, rhs)
|
(optimized_rhs , count_rhs ) = apply_opts(optimizations, rhs)
|
||||||
|
@ -55,7 +82,13 @@ def apply_opts(optimizations: Collection[Optimization], expression: Expression)
|
||||||
count = sum(count_branch for _, (_, count_branch) in branch_optimizations) + count_fallback + count_switching_on
|
count = sum(count_branch for _, (_, count_branch) in branch_optimizations) + count_fallback + count_switching_on
|
||||||
optimized_branches = {branch_num: optimized_branch for branch_num, (optimized_branch, _) in branch_optimizations}
|
optimized_branches = {branch_num: optimized_branch for branch_num, (optimized_branch, _) in branch_optimizations}
|
||||||
optimized_expr = Switch(optimized_branches, optimized_fallback, optimized_switching_on)
|
optimized_expr = Switch(optimized_branches, optimized_fallback, optimized_switching_on)
|
||||||
case Int(_) | Variable(_) | Builtin(_, _) | ReplHole(_, _) as optimized_expr:
|
case ReplHole(type_bindings, val_bindings):
|
||||||
|
val_optimizations = tuple((name, *apply_opts(optimizations, val)) for name, val in val_bindings)
|
||||||
|
print('VAL_OPTS:', val_optimizations)
|
||||||
|
count = sum(count_val_binding for _, _, count_val_binding in val_optimizations)
|
||||||
|
optimized_val_bindings = tuple((name, optimized_val_binding) for name, optimized_val_binding, _ in val_optimizations)
|
||||||
|
optimized_expr = ReplHole(type_bindings, optimized_val_bindings)
|
||||||
|
case Int(_) | Variable(_) | Builtin(_, _) as optimized_expr:
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
def fold_optimizations(acc: Tuple[Expression, int], optimization: Optimization) -> Tuple[Expression, int]:
|
def fold_optimizations(acc: Tuple[Expression, int], optimization: Optimization) -> Tuple[Expression, int]:
|
||||||
|
@ -77,4 +110,7 @@ def optimize_to_fixpoint(optimizations: Collection[Optimization], expression: Ex
|
||||||
return optimize_to_fixpoint(optimizations, optimized_expression)
|
return optimize_to_fixpoint(optimizations, optimized_expression)
|
||||||
raise Exception('Unreachable')
|
raise Exception('Unreachable')
|
||||||
|
|
||||||
all_optimizations: Sequence[Optimization] = (eliminate_single_let,)
|
all_optimizations: Sequence[Optimization] = (eliminate_single_let,
|
||||||
|
collapse_constant_additions,
|
||||||
|
eliminate_identity_operations,
|
||||||
|
identify_constant_additions)
|
Loading…
Reference in New Issue