from emis_funky_funktions import * from typing import Collection, Mapping, Sequence, Tuple, TypeAlias import types_ Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole' Pattern: TypeAlias = 'NamePattern | IntPattern | SPattern | IgnorePattern' @dataclass(frozen=True) class NamePattern: """ A pattern which always succeeds to match, and binds a whole expression to a name """ name: str def binds(self, var: str) -> bool: """ Test to see if this pattern binds a given variable """ return var == self.name def match(self, e: Expression) -> Option[Sequence[Tuple[str, Expression]]]: """ Match an expression against this pattern >>> NamePattern('my_var').match(Int(1)) Some((('my_var', 1),)) """ return Some(((self.name, e),)) def __repr__(self) -> str: return self.name def codegen(self, match_on: str) -> Tuple[Sequence[str], Collection[Tuple[str, str]]]: return (tuple(), [(self.name, match_on)]) @dataclass(frozen=True) class IgnorePattern: """ A pattern which always succeeds to match, but binds nothing """ def binds(self, var: str) -> bool: """ Test to see if this pattern binds a given variable For an `IgnorePattern` this is always false """ return False def match(self, e: Expression) -> Option[Sequence[Tuple[str, Expression]]]: """ Match an expression against this pattern >>> IgnorePattern().match(Int(1)) Some(()) """ return Some(tuple()) def __repr__(self) -> str: return '_' def codegen(self, match_on: str) -> Tuple[Sequence[str], Collection[Tuple[str, str]]]: return (tuple(), tuple()) @dataclass(frozen=True) class IntPattern: value: int def binds(self, var: str) -> bool: """ Test to see if this pattern binds a given variable For an `IntPattern` this is always false """ return False def match(self, e: Expression) -> Option[Sequence[Tuple[str, Expression]]]: """ Match an expression against this pattern >>> IntPattern(2).match(Int(1)) is None True >>> IntPattern(1).match(Int(1)) Some(()) """ match e: case Int(v) if v == self.value: return Some(tuple()) return None def __repr__(self) -> str: return repr(self.value) def codegen(self, match_on: str) -> Tuple[Sequence[str], Collection[Tuple[str, str]]]: return ((f'{match_on}=={self.value}',), tuple()) @dataclass(frozen=True) class SPattern: pred: Pattern def binds(self, var: str) -> bool: """ Test to see if this pattern binds a given variable """ return self.pred.binds(var) def match(self, e: Expression) -> Option[Sequence[Tuple[str, Expression]]]: """ Match an expression against this pattern >>> SPattern(NamePattern('n')).match(Int(1)) Some((('n', 0),)) >>> SPattern(NamePattern('n')).match(Int(0)) is None True >>> SPattern(SPattern(NamePattern('n'))).match(Int(4)) Some((('n', 2),)) """ match e: case Int(v) if v > 0: return self.pred.match(Int(v - 1)) return None def __repr__(self) -> str: return 'S ' + repr(self.pred) def codegen(self, match_on: str) -> Tuple[Sequence[str], Collection[Tuple[str, str]]]: pred_conditions, pred_bindings = self.pred.codegen(f'({match_on}-1)') return ((f'{match_on}>0', *pred_conditions), pred_bindings) @dataclass(frozen=True) class ReplHole: typ_bindings: types_.Context val_bindings: Sequence[Tuple[str, Expression]] = tuple() def subst(self, expression: Expression, variable: str) -> Expression: return ReplHole(self.typ_bindings, (*self.val_bindings, (variable, expression))) def is_value(self) -> bool: return True def step(self) -> Option[Expression]: return None def __repr__(self) -> str: return "[]" def codegen(self) -> str: return '[]' def render(self) -> str: return '\n'.join( f'{var_name} = ({var_expr.codegen()});' for (var_name, var_expr) in self.val_bindings ) @dataclass(frozen=True) class Builtin: name: str f: Callable[[Expression], Option[Expression]] js: str def subst(self, expression: Expression, variable: str) -> Expression: return self def is_value(self) -> bool: return True def step(self) -> Option[Expression]: return None def try_apply(self, v: Expression) -> Option[Expression]: return self.f(v) def __repr__(self) -> str: return "'" + repr(self.name)[1:-1] + "'" def codegen(self) -> str: return self.js @cur2 @staticmethod def _PLUS_CONST(i: int, e: Expression) -> Option[Expression]: match e: case Int(v): return Some(Int(i + v)) return None @staticmethod def _PLUS(e: Expression) -> Option[Expression]: match e: case Int(v): return Some(Builtin(f'+{v}', Builtin._PLUS_CONST(v), f'(x => x + {v})')) return None @staticmethod def PLUS() -> 'Builtin': return Builtin('+', Builtin._PLUS, '(x => y => x + y)') @staticmethod def S() -> 'Builtin': return Builtin('S', Builtin._PLUS_CONST(1), '(x => x + 1)') BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = ( ('+', Builtin.PLUS()), ('S', Builtin.S()), ) @dataclass(frozen=True) class Function: forms: Sequence[Tuple[Pattern, Expression]] def subst(self, expression: Expression, variable: str) -> Expression: return Function([ (p, e if p.binds(variable) else e.subst(expression, variable)) for (p, e) in self.forms ]) def is_value(self) -> bool: return True def step(self) -> Option[Expression]: return None def try_apply(self, v: Expression) -> Option[Expression]: match tuple((bindings.val, body) for (pattern, body) in self.forms for bindings in (pattern.match(v),) if bindings is not None): case []: return None case [(bindings, body), *rest]: return Some(subst_all(bindings, body)) raise Exception('Unreachable') def codegen_inner(self) -> str: return (':'.join( ( '&&'.join( iter(pattern.codegen('$')[0]) ) or '1' ) + '?' + ( '((' + ','.join( binding_name for (binding_name, binding_value) in pattern.codegen('$')[1] ) + f')=>({branch.codegen()}))(' + ','.join( binding_value for (binding_name, binding_value) in pattern.codegen('$')[1] ) + ')' if len(pattern.codegen('$')[1]) else branch.codegen() ) for (pattern, branch) in self.forms ) or '0?[].b') + ':[].b' def codegen(self) -> str: return '$=>' + self.codegen_inner() def codegen_named(self, name) -> str: return f'function {name}($){{return {self.codegen_inner()}}}' def __repr__(self) -> str: return '{ ' + ', '.join('"' + repr(repr(p))[1:-1] + '" : ' + repr(e) for (p, e) in self.forms) + ' }' @dataclass class LetBinding: lhs: str rhs: Expression body: Expression def subst(self, expression: Expression, variable: str) -> Expression: if self.lhs == variable: return self else: return LetBinding( self.lhs, self.rhs.subst(expression, variable), self.body.subst(expression, variable) ) def is_value(self) -> bool: return False def step(self) -> Option[Expression]: if self.rhs.is_value(): return Some(self.body.subst( self.rhs.subst( LetBinding(self.lhs, self.rhs, Variable(self.lhs)), self.lhs ), self.lhs )) else: return map_opt(lambda rhs_step: LetBinding(self.lhs, rhs_step, self.body), self.rhs.step() ) def __repr__(self) -> str: return f'( {repr(self.lhs)}, {repr(self.rhs)}, {repr(self.body)} )' def codegen(self) -> str: rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, Function) else self.rhs.codegen() return f'(({self.lhs}) => {self.body.codegen()})({rhs_cg})' @dataclass class Application: first: Expression arg: Expression def subst(self, expression: Expression, variable: str) -> Expression: return Application( self.first.subst(expression, variable), self.arg.subst(expression, variable) ) def is_value(self) -> bool: return False def step(self) -> Option[Expression]: match self.first.step(): case Some(first_stepped): return Some(Application(first_stepped, self.arg)) case None: match self.arg.step(): case Some(arg_stepped): return Some(Application(self.first, arg_stepped)) case None: assert isinstance(self.first, Function) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed" return self.first.try_apply(self.arg) raise Exception('Unreachable') def __repr__(self) -> str: return f'[ {repr(self.first)}, {repr(self.arg)} ]' def codegen(self) -> str: if isinstance(self.first, Function | Builtin) and self.arg.is_value(): return unwrap_opt(self.first.try_apply(self.arg)).codegen() else: return f'({self.first.codegen()})({self.arg.codegen()})' @dataclass class Int: value: int def subst(self, expression: Expression, variable: str) -> Expression: return self def is_value(self) -> bool: return True def step(self) -> Option[Expression]: return None def __repr__(self) -> str: return str(self.value) def codegen(self) -> str: return str(self.value) @dataclass class Variable: name: str def subst(self, expression: Expression, variable: str) -> Expression: if variable == self.name: return expression else: return self def is_value(self) -> bool: return False def step(self) -> Option[Expression]: match self.name: case '+': return Some(Builtin.PLUS()) case 'S': return Some(Builtin.S()) return None def __repr__(self) -> str: return '"' + repr(self.name)[1:-1] + '"' def codegen(self) -> str: return self.name def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> Expression: match bindings: case []: return body case [(var, replacement), *rest]: return subst_all(rest, body.subst(replacement, var)) raise Exception('Unreachable')