diff --git a/ir.py b/ir.py index bbf5fd9..1bada00 100644 --- a/ir.py +++ b/ir.py @@ -38,11 +38,30 @@ class ReplHole: if var_name not in types_.BUILTINS_CONTEXT ) +BuiltinBehavior: TypeAlias = 'Builtin.BB_PLUS_CONST | Builtin.BB_PLUS' + @dataclass(frozen=True) class Builtin: - name: str - f: Callable[[Expression], Option[Expression]] - js: str + behavior: BuiltinBehavior + + @dataclass(frozen=True) + class BB_PLUS_CONST: + amt: int + def name(self) -> str: + return f'{self.amt:+}' + def js(self) -> str: + return f'(x=>x{self.amt:+})' + def run(self, e: Expression) -> Option[Expression]: + return Some(Int(e.value + self.amt)) if isinstance(e, Int) else None + + @dataclass(frozen=True) + class BB_PLUS: + def name(self) -> str: + return '+' + def js(self) -> str: + return '(x=>y=>x+y)' + def run(self, e: Expression) -> Option[Expression]: + return Some(Builtin(Builtin.BB_PLUS_CONST(e.value))) if isinstance(e, Int) else None def subst(self, expression: Expression, variable: str) -> Expression: return self @@ -54,36 +73,16 @@ class Builtin: return None def try_apply(self, v: Expression) -> Option[Expression]: - return self.f(v) + return self.behavior.run(v) def __repr__(self) -> str: - return "'" + repr(self.name)[1:-1] + "'" + return "'" + repr(self.behavior.name())[1:-1] + "'" def codegen(self) -> str: - return self.js + return self.behavior.js() - @cur2 - @staticmethod - def _PLUS_CONST(i: int, e: Expression) -> Option[Expression]: - match e: - case Int(v): - return Some(Int(i + v)) - return None - - @staticmethod - def _PLUS(e: Expression) -> Option[Expression]: - match e: - case Int(v): - return Some(Builtin(f'+{v}', Builtin._PLUS_CONST(v), f'(x => x + {v})')) - return None - - @staticmethod - def PLUS() -> 'Builtin': - return Builtin('+', Builtin._PLUS, '(x => y => x + y)') - - @staticmethod - def S() -> 'Builtin': - return Builtin('S', Builtin._PLUS_CONST(1), '(x => x + 1)') + PLUS: 'Callable[[], Builtin]' = lambda: Builtin(Builtin.BB_PLUS()) + S: 'Callable[[], Builtin]' = lambda: Builtin(Builtin.BB_PLUS_CONST(1)) BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = ( ('+', Builtin.PLUS()), @@ -223,12 +222,10 @@ class Application: return unwrap_opt(self.first.try_apply(self.arg)).codegen() else: match self.first: - case Application(Builtin('+', _, _), addend1): + case Application(Builtin(Builtin.BB_PLUS), addend1): return f'({addend1.codegen()} + {self.arg.codegen()})' - case Builtin('S', _, _): - return f'(1+{self.arg.codegen()})' - case Builtin('pred', _, _): - return f'({self.arg.codegen()}-1)' + case Builtin(Builtin.BB_PLUS_CONST(n)): + return f'({self.arg.codegen()}{n:+})' return f'({self.first.codegen()})({self.arg.codegen()})' @dataclass @@ -340,7 +337,7 @@ def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Re def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]: def access_location(part: int) -> Callable[[Expression], Expression]: def remove(expr: Expression) -> Expression: - return Application(Builtin(f'pred', Builtin._PLUS_CONST(-1), f'$=>$-1'), expr) + return Application(Builtin(Builtin.BB_PLUS_CONST(-1)), expr) def access_location_prime(expr: Expression) -> Expression: if part < 1: return remove(expr)