from emis_funky_funktions import * from typing import Collection, Mapping, Sequence, Tuple, TypeAlias from functools import reduce from match_tree import MatchTree, MatchException, StructurePath, LeafNode, merge_all_trees, IntNode, EMPTY_STRUCT_PATH, FAIL_NODE from patterns import Pattern import types_ Expression: TypeAlias = 'MonoFunc | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch' Value: TypeAlias = 'MonoFunc | Int | Builtin | ReplHole' @dataclass(frozen=True) class ReplHole: typ_bindings: types_.Context val_bindings: Sequence[Tuple[str, Expression]] = tuple() def subst(self, expression: Expression, variable: str) -> Expression: return ReplHole(self.typ_bindings, (*self.val_bindings, (variable, expression))) def is_value(self) -> bool: return True def step(self) -> Option[Expression]: return None def __repr__(self) -> str: return "[]" def codegen(self) -> str: return '[]' def render(self) -> str: return '\n'.join( f'const {var_name} = ({var_expr.codegen()});' for (var_name, var_expr) in self.val_bindings if var_name not in types_.BUILTINS_CONTEXT ) @dataclass(frozen=True) class Builtin: name: str f: Callable[[Expression], Option[Expression]] js: str def subst(self, expression: Expression, variable: str) -> Expression: return self def is_value(self) -> bool: return True def step(self) -> Option[Expression]: return None def try_apply(self, v: Expression) -> Option[Expression]: return self.f(v) def __repr__(self) -> str: return "'" + repr(self.name)[1:-1] + "'" def codegen(self) -> str: return self.js @cur2 @staticmethod def _PLUS_CONST(i: int, e: Expression) -> Option[Expression]: match e: case Int(v): return Some(Int(i + v)) return None @staticmethod def _PLUS(e: Expression) -> Option[Expression]: match e: case Int(v): return Some(Builtin(f'+{v}', Builtin._PLUS_CONST(v), f'(x => x + {v})')) return None @staticmethod def PLUS() -> 'Builtin': return Builtin('+', Builtin._PLUS, '(x => y => x + y)') @staticmethod def S() -> 'Builtin': return Builtin('S', Builtin._PLUS_CONST(1), '(x => x + 1)') BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = ( ('+', Builtin.PLUS()), ('S', Builtin.S()), ) @dataclass(frozen=True) class MonoFunc: arg: str body: Expression def subst(self, expression: Expression, variable: str) -> Expression: if variable == self.arg: return self else: return MonoFunc(self.arg, self.body.subst(expression, variable)) def is_value(self) -> bool: return True def step(self) -> Option[Expression]: return None @staticmethod def from_match_function(forms: 'Sequence[Tuple[Pattern, Expression]]') -> Result[Expression, MatchException]: # In certain cases, starting a function with a full match tree may be unnecessary. # Specifically, if there exists only one possible branch and that branch binds only # one value and that value is equal to the whole entire input, rather than assigning # that input to a new variable, we may simply use argument variable instead. match forms: case [(patt, body)]: # A single possible branch match patt.bindings(): case []: # Binds nothing return Ok(MonoFunc('_', body)) case [(var, [])]: # Binds a single variable to the entire input return Ok(MonoFunc(var, body)) # If those special cases fail, we eliminate the pattern matching to produce a # single body: match_trees = tuple( # Construct a match tree for each possible branch pattern.match_tree( EMPTY_STRUCT_PATH, LeafNode.from_value(bindings_to_lets(pattern.bindings(), Variable('$'), body)) ) for (pattern, body) in forms ) unified_match_tree = merge_all_trees(match_trees) # Unify all the trees compiled_tree = compile_tree(unified_match_tree, Variable('$')) # Turn each tree into IR return compiled_tree <= p(MonoFunc, '$') def try_apply(self, v: Expression) -> Option[Expression]: return Some(self.body.subst(v, self.arg)) def codegen(self) -> str: return f'({self.arg}=>{self.body.codegen()})' def codegen_named(self, name) -> str: return f'(function {name}({self.arg}){{return {self.body.codegen()}}})' def __repr__(self) -> str: return f'{{{repr(self.arg)}: {repr(self.body)}}}' @dataclass class LetBinding: lhs: str rhs: Expression body: Expression def subst(self, expression: Expression, variable: str) -> Expression: if self.lhs == variable: return self else: return LetBinding( self.lhs, self.rhs.subst(expression, variable), self.body.subst(expression, variable) ) def is_value(self) -> bool: return False def step(self) -> Option[Expression]: if self.rhs.is_value(): return Some(self.body.subst( self.rhs.subst( LetBinding(self.lhs, self.rhs, Variable(self.lhs)), self.lhs ), self.lhs )) else: return map_opt(lambda rhs_step: LetBinding(self.lhs, rhs_step, self.body), self.rhs.step() ) def __repr__(self) -> str: return f'( "{self.lhs}", {repr(self.rhs)}, {repr(self.body)} )' def codegen(self) -> str: rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, MonoFunc) else self.rhs.codegen() return f'({self.lhs}=>{self.body.codegen()})({rhs_cg})' @dataclass class Application: first: Expression arg: Expression def subst(self, expression: Expression, variable: str) -> Expression: return Application( self.first.subst(expression, variable), self.arg.subst(expression, variable) ) def is_value(self) -> bool: return False def step(self) -> Option[Expression]: match self.first.step(): case Some(first_stepped): return Some(Application(first_stepped, self.arg)) case None: match self.arg.step(): case Some(arg_stepped): return Some(Application(self.first, arg_stepped)) case None: assert isinstance(self.first, MonoFunc) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed" return self.first.try_apply(self.arg) raise Exception('Unreachable') def __repr__(self) -> str: return f'[ {repr(self.first)}, {repr(self.arg)} ]' def codegen(self) -> str: if isinstance(self.first, MonoFunc | Builtin) and self.arg.is_value(): return unwrap_opt(self.first.try_apply(self.arg)).codegen() else: match self.first: case Application(Builtin('+', _, _), addend1): return f'({addend1.codegen()} + {self.arg.codegen()})' case Builtin('S', _, _): return f'(1+{self.arg.codegen()})' case Builtin('pred', _, _): return f'({self.arg.codegen()}-1)' return f'({self.first.codegen()})({self.arg.codegen()})' @dataclass class Int: value: int def subst(self, expression: Expression, variable: str) -> Expression: return self def is_value(self) -> bool: return True def step(self) -> Option[Expression]: return None def __repr__(self) -> str: return str(self.value) def codegen(self) -> str: return str(self.value) @dataclass class Variable: name: str def subst(self, expression: Expression, variable: str) -> Expression: if variable == self.name: return expression else: return self def is_value(self) -> bool: return False def step(self) -> Option[Expression]: match self.name: case '+': return Some(Builtin.PLUS()) case 'S': return Some(Builtin.S()) return None def __repr__(self) -> str: return '"' + repr(self.name)[1:-1] + '"' def codegen(self) -> str: return self.name @dataclass class Switch: branches: Mapping[int, Expression] fallback: Expression switching_on: Expression def subst(self, expression: Expression, variable: str) -> Expression: return Switch( {i: e.subst(expression, variable) for i, e in self.branches.items()}, self.fallback.subst(expression, variable), self.switching_on.subst(expression, variable)) def is_value(self) -> bool: return False def step(self) -> Option[Expression]: match self.switching_on.step(): case Some(switch_expr_stepped): return Some(Switch(self.branches, self.fallback, switch_expr_stepped)) case None: match self.switching_on: case Int(n): if n in self.branches: return Some(self.branches[n]) else: return Some(self.fallback) raise Exception('Attempted to switch on non-integer value') raise Exception('Unreachable') def __repr__(self) -> str: return '{ ' + ', '.join(f'{n}: ' + repr(e) for (n, e) in self.branches.items()) + f', _: {repr(self.fallback)}' + ' }' def codegen(self) -> str: switching_on_code = self.switching_on.codegen() return ':'.join( f'{switching_on_code}=={val}?({branch.codegen()})' for val, branch in self.branches.items() ) + f':{self.fallback.codegen()}' def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Result[Expression, MatchException]: match tree: case LeafNode([match]): return Ok(match) case LeafNode([]): return Err(MatchException.Incomplete) case LeafNode([a, b, *rest]): return Err(MatchException.Ambiguous) case IntNode(location, specific_trees, fallback_tree): access_location = location_to_ir(location)(match_against) match sequence(tuple(compile_tree(tree, match_against) for tree in specific_trees.values())): case Err(e): return Err(e) case Ok(exprs): match compile_tree(fallback_tree, match_against): case Err(e): return Err(e) case Ok(fallback): return Ok(Switch(dict(zip(specific_trees.keys(), exprs)), fallback, match_against)) raise Exception('Unreachable') def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]: def access_location(part: int) -> Callable[[Expression], Expression]: def remove(expr: Expression) -> Expression: return Application(Builtin(f'pred', Builtin._PLUS_CONST(-1), f'$=>$-1'), expr) def access_location_prime(expr: Expression) -> Expression: if part < 1: return remove(expr) else: raise AssertionError('A!') return access_location_prime match location: case []: return lambda o: o case [part, *rest_location]: return c(location_to_ir(StructurePath(rest_location)), access_location(part)) raise Exception('Unreachable') def bindings_to_lets(bindings: Collection[Tuple[str, StructurePath]], deconstructing_term: Expression, body_expr: Expression) -> Expression: match bindings: case []: return body_expr case [(binding_name, location), *rest]: return LetBinding(binding_name, location_to_ir(location)(deconstructing_term), bindings_to_lets(rest, deconstructing_term, body_expr)) raise Exception('Unreachable') def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> Expression: match bindings: case []: return body case [(var, replacement), *rest]: return subst_all(rest, body.subst(replacement, var)) raise Exception('Unreachable') def count_uses(variable: str, expression: Expression) -> int: match expression: case MonoFunc(arg, body): return 0 if arg == variable else count_uses(variable, body) case Application(first, arg): return count_uses(variable, first) + count_uses(variable, arg) case Int(_): return 0 case Variable(name): return 1 if name == variable else 0 case Builtin(_, _): return 0 case LetBinding(lhs, rhs, body): return count_uses(variable, rhs) + count_uses(variable, body) case ReplHole(_, _): return 0 case Switch(branches, fallback, switching_on): return ( count_uses(variable, switching_on) + count_uses(variable, fallback) + sum(count_uses(variable, branch) for branch in branches.values()))