348 lines
9.8 KiB
Python
348 lines
9.8 KiB
Python
from emis_funky_funktions import *
|
|
|
|
from typing import Collection, Mapping, Sequence, Tuple, TypeAlias
|
|
from functools import reduce
|
|
from match_tree import MatchException, StructurePath, LeafNode, merge_all_trees, IntNode
|
|
|
|
import types_
|
|
|
|
|
|
Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
|
|
Value: TypeAlias = 'Function | Int | Builtin | ReplHole'
|
|
|
|
@dataclass(frozen=True)
class ReplHole:
    """A hole, rendered as ``[]``, that records substitutions instead of applying them.

    Presumably this stands for "the rest of the REPL session" (inferred from
    the name; confirm with callers). Every value substituted into the hole is
    remembered in order, so that ``render`` can emit it later.
    """

    # Type-level context; carried through `subst` unchanged.
    typ_bindings: types_.Context
    # (variable, expression) pairs substituted so far, oldest first.
    val_bindings: Sequence[Tuple[str, Expression]] = ()

    def subst(self, expression: Expression, variable: str) -> Expression:
        """Return a new hole with ``(variable, expression)`` appended to its bindings."""
        extended_bindings = tuple(self.val_bindings) + ((variable, expression),)
        return ReplHole(self.typ_bindings, extended_bindings)

    def is_value(self) -> bool:
        # A hole is treated as a fully evaluated value.
        return True

    def step(self) -> Option[Expression]:
        # Values never step.
        return None

    def __repr__(self) -> str:
        return "[]"

    def codegen(self) -> str:
        return '[]'

    def render(self) -> str:
        """Emit one ``name = (code);`` line per recorded binding, in recording order."""
        assignment_lines = [
            f'{name} = ({bound_expr.codegen()});'
            for name, bound_expr in self.val_bindings
        ]
        return '\n'.join(assignment_lines)
|
|
|
|
@dataclass(frozen=True)
class Builtin:
    """A primitive operation implemented in Python, paired with its JavaScript source."""

    # Display name used by `__repr__`.
    name: str
    # Python implementation: returns Some(result), or None when the argument is
    # not acceptable (e.g. `_PLUS_CONST` given a non-`Int`).
    f: Callable[[Expression], Option[Expression]]
    # JavaScript source emitted by `codegen`.
    js: str

    def subst(self, expression: Expression, variable: str) -> Expression:
        # A builtin holds no IR sub-expressions, so substitution is the identity.
        return self

    def is_value(self) -> bool:
        return True

    def step(self) -> Option[Expression]:
        # Values never step.
        return None

    def try_apply(self, v: Expression) -> Option[Expression]:
        """Run the Python implementation on ``v``."""
        return self.f(v)

    def __repr__(self) -> str:
        # Strip the quotes repr() adds, then wrap the escaped text in single quotes.
        escaped_name = repr(self.name)[1:-1]
        return f"'{escaped_name}'"

    def codegen(self) -> str:
        return self.js

    @cur2
    @staticmethod
    def _PLUS_CONST(i: int, e: Expression) -> Option[Expression]:
        """Return ``Some(Int(i + e.value))`` when ``e`` is an ``Int``; otherwise None.

        Wrapped by ``cur2``; callers use it as ``_PLUS_CONST(i)``, which yields a
        one-argument function.
        """
        if isinstance(e, Int):
            return Some(Int(i + e.value))
        return None

    @staticmethod
    def _PLUS(e: Expression) -> Option[Expression]:
        """Partially apply ``+`` to an ``Int``, producing an "add this constant" builtin."""
        if not isinstance(e, Int):
            return None
        addend = e.value
        return Some(Builtin(f'+{addend}', Builtin._PLUS_CONST(addend), f'(x => x + {addend})'))

    @staticmethod
    def PLUS() -> 'Builtin':
        """The curried two-argument addition builtin ``+``."""
        return Builtin('+', Builtin._PLUS, '(x => y => x + y)')

    @staticmethod
    def S() -> 'Builtin':
        """The successor builtin ``S`` (adds 1)."""
        return Builtin('S', Builtin._PLUS_CONST(1), '(x => x + 1)')
|
|
|
|
# Names with a built-in meaning, each paired with the Builtin it denotes
# (presumably substituted into programs before evaluation; confirm with callers).
BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = tuple(
    (builtin_name, make_builtin())
    for builtin_name, make_builtin in (
        ('+', Builtin.PLUS),
        ('S', Builtin.S),
    )
)
|
|
|
|
@dataclass(frozen=True)
class Function:
    """A function defined by pattern matching over an ordered list of forms.

    Each form pairs a pattern with the body to evaluate when that pattern
    matches the argument.
    """

    # (pattern, body) pairs.
    forms: 'Sequence[Tuple[patterns.Pattern, Expression]]'

    def subst(self, expression: Expression, variable: str) -> Expression:
        """Substitute ``expression`` for ``variable`` in every form body.

        A form whose pattern binds ``variable`` shadows it, so that form's body
        is left untouched.
        """
        return Function([
            (p, e if p.binds(variable) else e.subst(expression, variable))
            for (p, e) in self.forms
        ])

    def is_value(self) -> bool:
        return True

    def step(self) -> Option[Expression]:
        # Functions are values and never step.
        return None

    def function_to_tree(self) -> 'MatchTree[Expression]':
        """Merge the per-form match trees into one tree whose leaves are the raw bodies.

        Unlike ``eliminate``, the bodies are not wrapped in let-bindings for the
        pattern variables, and the merged tree is not compiled.
        """
        subtrees = tuple(
            patt.match_tree([], LeafNode.from_value(expr))
            for (patt, expr) in self.forms
        )
        # Use the same merge helper as `eliminate`; `merge_trees` is not
        # imported in this module.
        # NOTE(review): assumes merge_all_trees folds from an empty leaf, as the
        # previous reduce(..., LeafNode(tuple())) intended. Confirm.
        return merge_all_trees(subtrees)

    def eliminate(self, v: Expression) -> Result[Expression, MatchException]:
        """Compile matching ``v`` against the forms into a single expression.

        Each body is wrapped in let-bindings that bind its pattern's variables
        to the corresponding parts of ``v``. The per-form match trees are then
        merged and compiled with ``compile_tree``, which yields ``Err`` when the
        match is incomplete or ambiguous.
        """
        match_trees = tuple(
            pattern.match_tree([], LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body)))
            for (pattern, body) in self.forms
        )
        unified_match_tree = merge_all_trees(match_trees)
        return compile_tree(unified_match_tree, v)

    def try_apply(self, v: Expression) -> Option[Expression]:
        """Apply this function to ``v``.

        ``hush`` presumably turns a match failure (``Err``) into None.
        """
        return hush(self.eliminate(v))

    def codegen_inner(self) -> str:
        """JavaScript for the function body, with the argument named ``$``."""
        return unwrap_r(self.eliminate(Variable('$'))).codegen()

    def codegen(self) -> str:
        """JavaScript arrow function ``$=>body``."""
        return '$=>' + self.codegen_inner()

    def codegen_named(self, name: str) -> str:
        """JavaScript named function expression, so the body can refer to ``name``."""
        return f'function {name}($){{return {self.codegen_inner()}}}'

    def __repr__(self) -> str:
        return '{ ' + ', '.join('"' + repr(repr(p))[1:-1] + '" : ' + repr(e) for (p, e) in self.forms) + ' }'
|
|
|
|
@dataclass
class LetBinding:
    """A recursive ``let lhs = rhs in body`` binding.

    ``lhs`` is in scope in both ``rhs`` and ``body``. When ``rhs`` is a value,
    each ``lhs`` inside it is replaced by this binding wrapped around
    ``Variable(lhs)``, which ties the recursive knot.
    """

    # Name being bound.
    lhs: str
    # Bound expression (may refer to `lhs` recursively).
    rhs: Expression
    # Expression in which `lhs` is visible.
    body: Expression

    def subst(self, expression: Expression, variable: str) -> Expression:
        # `lhs` shadows `variable` in both rhs and body, so nothing changes.
        if self.lhs == variable:
            return self
        substituted_rhs = self.rhs.subst(expression, variable)
        substituted_body = self.body.subst(expression, variable)
        return LetBinding(self.lhs, substituted_rhs, substituted_body)

    def is_value(self) -> bool:
        return False

    def step(self) -> Option[Expression]:
        """Reduce ``rhs`` until it is a value, then substitute it into ``body``.

        While ``rhs`` is not a value, this steps ``rhs``; ``map_opt`` presumably
        yields None when ``rhs`` cannot step.
        """
        if not self.rhs.is_value():
            return map_opt(
                lambda next_rhs: LetBinding(self.lhs, next_rhs, self.body),
                self.rhs.step(),
            )
        # Inside rhs, lhs refers back to this same binding.
        unrolled_self = LetBinding(self.lhs, self.rhs, Variable(self.lhs))
        closed_rhs = self.rhs.subst(unrolled_self, self.lhs)
        return Some(self.body.subst(closed_rhs, self.lhs))

    def __repr__(self) -> str:
        return f'( {self.lhs!r}, {self.rhs!r}, {self.body!r} )'

    def codegen(self) -> str:
        """Emit an immediately invoked arrow function that binds ``lhs``."""
        if isinstance(self.rhs, Function):
            # A named function expression lets the function call itself by name.
            bound_code = self.rhs.codegen_named(self.lhs)
        else:
            bound_code = self.rhs.codegen()
        return f'(({self.lhs}) => {self.body.codegen()})({bound_code})'
|
|
|
|
@dataclass
|
|
class Application:
|
|
first: Expression
|
|
arg: Expression
|
|
|
|
def subst(self, expression: Expression, variable: str) -> Expression:
|
|
return Application(
|
|
self.first.subst(expression, variable),
|
|
self.arg.subst(expression, variable)
|
|
)
|
|
|
|
def is_value(self) -> bool:
|
|
return False
|
|
|
|
def step(self) -> Option[Expression]:
|
|
match self.first.step():
|
|
case Some(first_stepped):
|
|
return Some(Application(first_stepped, self.arg))
|
|
case None:
|
|
match self.arg.step():
|
|
case Some(arg_stepped):
|
|
return Some(Application(self.first, arg_stepped))
|
|
case None:
|
|
assert isinstance(self.first, Function) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
|
|
return self.first.try_apply(self.arg)
|
|
raise Exception('Unreachable')
|
|
|
|
def __repr__(self) -> str:
|
|
return f'[ {repr(self.first)}, {repr(self.arg)} ]'
|
|
|
|
def codegen(self) -> str:
|
|
if isinstance(self.first, Function | Builtin) and self.arg.is_value():
|
|
return unwrap_opt(self.first.try_apply(self.arg)).codegen()
|
|
else:
|
|
return f'({self.first.codegen()})({self.arg.codegen()})'
|
|
|
|
@dataclass
class Int:
    """An integer literal; it is already a value and never steps."""

    # The literal's integer value.
    value: int

    def subst(self, expression: Expression, variable: str) -> Expression:
        # A literal contains no variables, so substitution leaves it unchanged.
        return self

    def is_value(self) -> bool:
        return True

    def step(self) -> Option[Expression]:
        return None

    def __repr__(self) -> str:
        decimal_text = str(self.value)
        return decimal_text

    def codegen(self) -> str:
        # JavaScript integer literals share Python's decimal spelling.
        decimal_text = str(self.value)
        return decimal_text
|
|
|
|
@dataclass
class Variable:
    """A reference to a variable by name."""

    # Identifier being referenced.
    name: str

    def subst(self, expression: Expression, variable: str) -> Expression:
        return expression if variable == self.name else self

    def is_value(self) -> bool:
        return False

    def step(self) -> Option[Expression]:
        """Step a builtin name to its Builtin; any other variable is stuck (None)."""
        if self.name == '+':
            return Some(Builtin.PLUS())
        if self.name == 'S':
            return Some(Builtin.S())
        return None

    def __repr__(self) -> str:
        # Strip the quotes repr() adds, then wrap the escaped text in double quotes.
        escaped_name = repr(self.name)[1:-1]
        return f'"{escaped_name}"'

    def codegen(self) -> str:
        return self.name
|
|
|
|
@dataclass
|
|
class Switch:
|
|
branches: Mapping[int, Expression]
|
|
fallback: Expression
|
|
switching_on: Expression
|
|
|
|
def subst(self, expression: Expression, variable: str) -> Expression:
|
|
return Switch(
|
|
{i: e.subst(expression, variable) for i, e in self.branches.items()},
|
|
self.switching_on.subst(expression, variable))
|
|
|
|
def is_value(self) -> bool:
|
|
return False
|
|
|
|
def step(self) -> Option[Expression]:
|
|
match self.switching_on.step():
|
|
case Some(switch_expr_stepped):
|
|
return Switch(self.branches, switch_expr_stepped)
|
|
case None:
|
|
match self.switching_on:
|
|
case Int(n):
|
|
if n in self.branches:
|
|
return Some(self.branches[n])
|
|
else:
|
|
return None
|
|
raise Exception('Attempted to switch on non-integer value')
|
|
raise Exception('Unreachable')
|
|
|
|
def __repr__(self) -> str:
|
|
return '{ ' + ', '.join(f'{n}: ' + repr(e) for (n, e) in self.branches.items()) + f', _: {repr(self.fallback)}' + ' }'
|
|
|
|
def codegen(self) -> str:
|
|
switching_on_code = self.switching_on.codegen()
|
|
return ':'.join(
|
|
f'{switching_on_code}==={val}?({branch.codegen()})'
|
|
for val, branch in self.branches.items()
|
|
) + f':{self.fallback.codegen()}'
|
|
|
|
def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Result[Expression, MatchException]:
|
|
|
|
def match_int_builtin(lookup: Callable[[int], Expression]) -> Callable[[Expression], Option[Expression]]:
|
|
def match_inner(i: Expression) -> Expression:
|
|
match i:
|
|
case Int(value):
|
|
return lookup(value)
|
|
raise Exception('Bad type! Eep!')
|
|
return match_inner
|
|
|
|
match tree:
|
|
case LeafNode([match]):
|
|
return Ok(match)
|
|
case LeafNode([]):
|
|
return Err(MatchException.Incomplete)
|
|
case LeafNode([a, b, *rest]):
|
|
return Err(MatchException.Ambiguous)
|
|
case IntNode(location, specific_trees, fallback_tree):
|
|
access_location = location_to_ir(location)(match_against)
|
|
match sequence(tuple(compile_tree(tree, match_against) for tree in specific_trees.values())):
|
|
case Err(e):
|
|
return Err(e)
|
|
case Ok(exprs):
|
|
match compile_tree(fallback_tree, match_against):
|
|
case Err(e):
|
|
return Err(e)
|
|
case Ok(fallback):
|
|
return Ok(Switch(dict(zip(specific_trees.keys(), exprs)), fallback, match_against))
|
|
|
|
def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]:
|
|
def access_location(part: int) -> Callable[[Expression], Expression]:
|
|
def remove(expr: Expression) -> Expression:
|
|
return Application(Builtin(f'{part}', Builtin._PLUS_CONST(-1), f'$=>$-1'), expr)
|
|
def access_location_prime(expr: Expression) -> Expression:
|
|
if part < 1:
|
|
return remove(expr)
|
|
else:
|
|
raise AssertionError('A!')
|
|
return access_location_prime
|
|
match location:
|
|
case []:
|
|
return lambda o: o
|
|
case [part, *rest_location]:
|
|
return c(location_to_ir(StructurePath(rest_location)), access_location(part))
|
|
raise Exception('Unreachable')
|
|
|
|
def bindings_to_lets(bindings: Sequence[Tuple[str, StructurePath]], deconstructing_term: Expression, body_expr: Expression) -> Expression:
    """Wrap ``body_expr`` in one ``LetBinding`` per pattern variable.

    Each ``(name, location)`` pair binds ``name`` to the part of
    ``deconstructing_term`` at ``location``. The first pair becomes the
    outermost ``LetBinding``. With no bindings, ``body_expr`` is returned
    unchanged. Any iterable of pairs is accepted; the previous version
    implicitly returned None for non-sequence input.
    """
    result = body_expr
    # Build from the innermost binding outward so the first pair ends up outermost.
    for binding_name, location in reversed(tuple(bindings)):
        result = LetBinding(binding_name, location_to_ir(location)(deconstructing_term), result)
    return result
|
|
|
|
def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> Expression:
    """Apply each ``(variable, replacement)`` substitution to ``body``, in order.

    Earlier pairs are applied first, so a replacement introduced by one pair
    can itself be rewritten by a later pair.

    The loop is iterative. The previous recursion copied the remaining tail on
    every call (quadratic) and was bounded by the interpreter's recursion limit
    for long binding lists. Any iterable of pairs is accepted.
    """
    for var, replacement in bindings:
        body = body.subst(replacement, var)
    return body