JSON-Lang/ir.py

377 lines
11 KiB
Python

from emis_funky_funktions import *
from typing import Collection, Mapping, Sequence, Tuple, TypeAlias
from functools import reduce
from match_tree import MatchTree, MatchException, StructurePath, LeafNode, merge_all_trees, IntNode, EMPTY_STRUCT_PATH, FAIL_NODE
from patterns import Pattern
import types_
# Every node type that may appear in the IR.  The alias is written as a string
# (a forward reference) because the classes it names are defined further down.
Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
# The node types whose `is_value()` returns True and whose `step()` returns None.
Value: TypeAlias = 'Function | Int | Builtin | ReplHole'
@dataclass(frozen=True)
class ReplHole:
    """
    A placeholder node which records substitutions instead of performing them.

    Each call to `subst` yields a new hole remembering one more
    `(variable, expression)` pair; `render` turns the recorded pairs into
    JavaScript `const` declarations.
    """
    # Typing context carried along with the hole (no method here reads it)
    typ_bindings: types_.Context
    # (name, expression) pairs, in the order they were substituted
    val_bindings: Sequence[Tuple[str, Expression]] = ()

    def subst(self, expression: Expression, variable: str) -> Expression:
        """ Return a new hole with `(variable, expression)` appended to the recorded bindings """
        extended = tuple(self.val_bindings) + ((variable, expression),)
        return ReplHole(self.typ_bindings, extended)

    def is_value(self) -> bool:
        """ A hole is always treated as a value """
        return True

    def step(self) -> Option[Expression]:
        """ A hole never takes a step """
        return None

    def __repr__(self) -> str:
        return "[]"

    def codegen(self) -> str:
        """ Emit the JavaScript expression standing in for the hole """
        return '[]'

    def render(self) -> str:
        """
        Emit one `const` declaration per recorded binding, one per line.

        Bindings whose name appears in `types_.BUILTINS_CONTEXT` are skipped.
        """
        declarations = [
            f'const {name} = ({bound.codegen()});'
            for name, bound in self.val_bindings
            if name not in types_.BUILTINS_CONTEXT
        ]
        return '\n'.join(declarations)
@dataclass(frozen=True)
class Builtin:
    """
    A primitive function implemented in Python, paired with the JavaScript it compiles to.
    """
    # Name shown by `__repr__`
    name: str
    # Python implementation; yields None when the argument cannot be handled
    f: Callable[[Expression], Option[Expression]]
    # JavaScript source emitted verbatim by `codegen`
    js: str

    def subst(self, expression: Expression, variable: str) -> Expression:
        """ Substitution leaves a builtin unchanged """
        return self

    def is_value(self) -> bool:
        """ A builtin is always a value """
        return True

    def step(self) -> Option[Expression]:
        """ A builtin never takes a step on its own """
        return None

    def try_apply(self, v: Expression) -> Option[Expression]:
        """ Run the Python implementation on `v` """
        return self.f(v)

    def __repr__(self) -> str:
        return f"'{repr(self.name)[1:-1]}'"

    def codegen(self) -> str:
        """ Emit the builtin's JavaScript source """
        return self.js

    # NOTE: the decorator order is deliberate; `cur2` receives the staticmethod
    # object so that `Builtin._PLUS_CONST(i)` yields a one-argument function.
    @cur2
    @staticmethod
    def _PLUS_CONST(i: int, e: Expression) -> Option[Expression]:
        """ Add the constant `i` to `e` when `e` is an `Int`; otherwise yield None """
        if isinstance(e, Int):
            return Some(Int(i + e.value))
        return None

    @staticmethod
    def _PLUS(e: Expression) -> Option[Expression]:
        """ Partially apply addition: an `Int` argument yields a builtin adding that constant """
        if isinstance(e, Int):
            addend = e.value
            return Some(Builtin(f'+{addend}', Builtin._PLUS_CONST(addend), f'(x => x + {addend})'))
        return None

    @staticmethod
    def PLUS() -> 'Builtin':
        """ The two-argument (curried) addition builtin """
        return Builtin('+', Builtin._PLUS, '(x => y => x + y)')

    @staticmethod
    def S() -> 'Builtin':
        """ The successor builtin, which adds one to its argument """
        return Builtin('S', Builtin._PLUS_CONST(1), '(x => x + 1)')
# Pairs each builtin's source-level name with the `Builtin` value it stands for.
# NOTE(review): presumably fed to `subst_all` to resolve builtin names in a
# program; no use of this constant is visible here, so confirm against callers.
BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = (
('+', Builtin.PLUS()),
('S', Builtin.S()),
)
@dataclass(frozen=True)
class Function:
forms: 'Sequence[Tuple[Pattern, Expression]]'
def subst(self, expression: Expression, variable: str) -> Expression:
return Function([
(p, e if p.binds(variable) else e.subst(expression, variable))
for (p, e) in self.forms
])
def is_value(self) -> bool:
return True
def step(self) -> Option[Expression]:
return None
def eliminate(self, v: Expression) -> Result[Expression, MatchException]:
match_trees = tuple(pattern.match_tree(EMPTY_STRUCT_PATH, LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body))) for (pattern, body) in self.forms)
unified_match_tree = merge_all_trees(match_trees)
return compile_tree(unified_match_tree, v)
def try_apply(self, v: Expression) -> Option[Expression]:
return hush(self.eliminate(v))
def codegen_inner(self) -> str:
return unwrap_r(self.eliminate(Variable('$'))).codegen()
def try_codegen_sp(self) -> Option[Tuple[str, str]]:
""" A special-case of codegen inner (see description)
In certain cases, starting a function with a full match tree may be unnecessary.
Specifically, if there exists only one possible branch and that branch binds only
one value and that value is equal to the whole entire input, rather than assigning
that input to a new variable, we may simply use argument variable instead.
This method returns the generated code for the inner branch in such a case.
Additionally, the second string returned represents the name of the variable which
ought to be bound as the argument. If the argument is unused, this will be "$".
"""
match self.forms:
case [(patt, expr)]: # A single possible branch
match patt.bindings():
case []: # Binds nothing
return Some((expr.codegen(), '$'))
case [(var, [])]: # Binds a single variable to the entire input
return Some((expr.codegen(), var))
return None
def codegen(self) -> str:
match self.try_codegen_sp():
case Some((codegen, var)):
return f'{var}=>{codegen}'
case None:
return '$=>' + self.codegen_inner()
raise Exception('Unreachable')
def codegen_named(self, name) -> str:
match self.try_codegen_sp():
case Some((codegen, var)):
return f'function {name}({var}){{return {codegen}}}'
case None:
return f'function {name}($){{return {self.codegen_inner()}}}'
raise Exception('Unreachable')
def __repr__(self) -> str:
return '{ ' + ', '.join('"' + repr(repr(p))[1:-1] + '" : ' + repr(e) for (p, e) in self.forms) + ' }'
@dataclass
class LetBinding:
    """
    Binds `lhs` to `rhs` within `body`.

    `step` also substitutes for `lhs` inside `rhs`, so `rhs` may refer to
    itself by name.
    """
    lhs: str
    rhs: Expression
    body: Expression

    def subst(self, expression: Expression, variable: str) -> Expression:
        """ Substitute in both `rhs` and `body`, unless this binding shadows `variable` """
        if variable == self.lhs:
            return self
        new_rhs = self.rhs.subst(expression, variable)
        new_body = self.body.subst(expression, variable)
        return LetBinding(self.lhs, new_rhs, new_body)

    def is_value(self) -> bool:
        """ A let-binding is never a value """
        return False

    def step(self) -> Option[Expression]:
        """
        Step `rhs` until it is a value, then substitute it into `body`.
        """
        if not self.rhs.is_value():
            return map_opt(
                lambda stepped_rhs: LetBinding(self.lhs, stepped_rhs, self.body),
                self.rhs.step()
            )
        # Occurrences of `lhs` inside the value are replaced by a fresh let that
        # re-binds `lhs` to the same value, which is what allows self-reference.
        self_reference = LetBinding(self.lhs, self.rhs, Variable(self.lhs))
        unrolled_rhs = self.rhs.subst(self_reference, self.lhs)
        return Some(self.body.subst(unrolled_rhs, self.lhs))

    def __repr__(self) -> str:
        return f'( {self.lhs!r}, {self.rhs!r}, {self.body!r} )'

    def codegen(self) -> str:
        """
        Emit an immediately-applied arrow function binding `lhs`.

        A `Function` on the right-hand side is emitted as a function expression
        named `lhs`.
        """
        if isinstance(self.rhs, Function):
            rhs_js = self.rhs.codegen_named(self.lhs)
        else:
            rhs_js = self.rhs.codegen()
        return f'({self.lhs}=>{self.body.codegen()})({rhs_js})'
@dataclass
class Application:
first: Expression
arg: Expression
def subst(self, expression: Expression, variable: str) -> Expression:
return Application(
self.first.subst(expression, variable),
self.arg.subst(expression, variable)
)
def is_value(self) -> bool:
return False
def step(self) -> Option[Expression]:
match self.first.step():
case Some(first_stepped):
return Some(Application(first_stepped, self.arg))
case None:
match self.arg.step():
case Some(arg_stepped):
return Some(Application(self.first, arg_stepped))
case None:
assert isinstance(self.first, Function) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
return self.first.try_apply(self.arg)
raise Exception('Unreachable')
def __repr__(self) -> str:
return f'[ {repr(self.first)}, {repr(self.arg)} ]'
def codegen(self) -> str:
if isinstance(self.first, Function | Builtin) and self.arg.is_value():
return unwrap_opt(self.first.try_apply(self.arg)).codegen()
else:
match self.first:
case Application(Builtin('+', _, _), addend1):
return f'({addend1.codegen()} + {self.arg.codegen()})'
case Builtin('S', _, _):
return f'(1+{self.arg.codegen()})'
case Builtin('pred', _, _):
return f'({self.arg.codegen()}-1)'
return f'({self.first.codegen()})({self.arg.codegen()})'
@dataclass
class Int:
    """
    An integer literal.
    """
    value: int

    def subst(self, expression: Expression, variable: str) -> Expression:
        """ A literal contains no variables, so substitution is the identity """
        return self

    def is_value(self) -> bool:
        """ A literal is always a value """
        return True

    def step(self) -> Option[Expression]:
        """ A literal never takes a step """
        return None

    def __repr__(self) -> str:
        return f'{self.value}'

    def codegen(self) -> str:
        """ Emit the literal in decimal """
        return f'{self.value}'
@dataclass
class Variable:
    """
    A reference to a variable by name.
    """
    name: str

    def subst(self, expression: Expression, variable: str) -> Expression:
        """ Yield `expression` when this is the variable being replaced, else `self` """
        return expression if variable == self.name else self

    def is_value(self) -> bool:
        """ A variable is never a value """
        return False

    def step(self) -> Option[Expression]:
        """
        Resolve the names `+` and `S` to their builtins; no other variable can step.
        """
        if self.name == '+':
            return Some(Builtin.PLUS())
        if self.name == 'S':
            return Some(Builtin.S())
        return None

    def __repr__(self) -> str:
        return f'"{repr(self.name)[1:-1]}"'

    def codegen(self) -> str:
        """ Emit the variable's name unchanged """
        return self.name
@dataclass
class Switch:
branches: Mapping[int, Expression]
fallback: Expression
switching_on: Expression
def subst(self, expression: Expression, variable: str) -> Expression:
return Switch(
{i: e.subst(expression, variable) for i, e in self.branches.items()},
self.fallback,
self.switching_on.subst(expression, variable))
def is_value(self) -> bool:
return False
def step(self) -> Option[Expression]:
match self.switching_on.step():
case Some(switch_expr_stepped):
return Some(Switch(self.branches, self.fallback, switch_expr_stepped))
case None:
match self.switching_on:
case Int(n):
if n in self.branches:
return Some(self.branches[n])
else:
return Some(self.fallback)
raise Exception('Attempted to switch on non-integer value')
raise Exception('Unreachable')
def __repr__(self) -> str:
return '{ ' + ', '.join(f'{n}: ' + repr(e) for (n, e) in self.branches.items()) + f', _: {repr(self.fallback)}' + ' }'
def codegen(self) -> str:
switching_on_code = self.switching_on.codegen()
return ':'.join(
f'{switching_on_code}=={val}?({branch.codegen()})'
for val, branch in self.branches.items()
) + f':{self.fallback.codegen()}'
def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Result[Expression, MatchException]:
match tree:
case LeafNode([match]):
return Ok(match)
case LeafNode([]):
return Err(MatchException.Incomplete)
case LeafNode([a, b, *rest]):
return Err(MatchException.Ambiguous)
case IntNode(location, specific_trees, fallback_tree):
access_location = location_to_ir(location)(match_against)
match sequence(tuple(compile_tree(tree, match_against) for tree in specific_trees.values())):
case Err(e):
return Err(e)
case Ok(exprs):
match compile_tree(fallback_tree, match_against):
case Err(e):
return Err(e)
case Ok(fallback):
return Ok(Switch(dict(zip(specific_trees.keys(), exprs)), fallback, match_against))
raise Exception('Unreachable')
def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]:
def access_location(part: int) -> Callable[[Expression], Expression]:
def remove(expr: Expression) -> Expression:
return Application(Builtin(f'pred', Builtin._PLUS_CONST(-1), f'$=>$-1'), expr)
def access_location_prime(expr: Expression) -> Expression:
if part < 1:
return remove(expr)
else:
raise AssertionError('A!')
return access_location_prime
match location:
case []:
return lambda o: o
case [part, *rest_location]:
return c(location_to_ir(StructurePath(rest_location)), access_location(part))
raise Exception('Unreachable')
def bindings_to_lets(bindings: Collection[Tuple[str, StructurePath]], deconstructing_term: Expression, body_expr: Expression) -> Expression:
match bindings:
case []:
return body_expr
case [(binding_name, location), *rest]:
return LetBinding(binding_name, location_to_ir(location)(deconstructing_term), bindings_to_lets(rest, deconstructing_term, body_expr))
raise Exception('Unreachable')
def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> Expression:
match bindings:
case []:
return body
case [(var, replacement), *rest]:
return subst_all(rest, body.subst(replacement, var))
raise Exception('Unreachable')