JSON-Lang/ir.py

408 lines
12 KiB
Python

from emis_funky_funktions import *
from typing import Collection, Mapping, Sequence, Tuple, TypeAlias
from functools import reduce
from match_tree import MatchTree, MatchException, StructurePath, LeafNode, merge_all_trees, IntNode, EMPTY_STRUCT_PATH, FAIL_NODE
from patterns import Pattern
import types_
# Union of every node class that can appear in the IR.  It is spelled as a string
# (a forward reference) because the classes are defined further down in this module.
Expression: TypeAlias = 'MonoFunc | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
# The node classes whose `is_value()` returns True, i.e. terms that cannot `step` any further.
Value: TypeAlias = 'MonoFunc | Int | Builtin | ReplHole'
# How many names `mk_dollar` has handed out so far; also the index the next name will use.
dollar_count: int = 0


def mk_dollar() -> str:
    """Return a fresh name of the form `$<n>`, with `<n>` counting up from 0.

    Every call advances the module-level `dollar_count`, so successive calls
    never return the same name.
    """
    global dollar_count
    issued_index = dollar_count
    dollar_count = issued_index + 1
    return f'${issued_index}'
@dataclass(frozen=True)
class ReplHole:
typ_bindings: types_.Context
val_bindings: Sequence[Tuple[str, Expression]] = tuple()
def subst(self, expression: Expression, variable: str) -> Expression:
return ReplHole(self.typ_bindings, (*self.val_bindings, (variable, expression)))
def is_value(self) -> bool:
return True
def step(self) -> Option[Expression]:
return None
def __repr__(self) -> str:
return "[]"
def codegen(self) -> str:
return '[]'
def render(self) -> str:
def render_var(var_name: str, var_expr: Expression) -> str:
match var_expr:
case LetBinding(prev_func_name, MonoFunc(formal, func_body), Variable(mb_prev_func_name)):
if mb_prev_func_name == prev_func_name:
func_body_with_new_name = func_body.subst(Variable(var_name), prev_func_name)
return f'function {var_name}({formal}) {{return {func_body_with_new_name.codegen()}}}'
else:
raise Exception('This useless LetBinding should have been removed by the optimization stage')
case _:
return f'const {var_name} = ({var_expr.codegen()});'
return '\n'.join(
render_var(var_name, var_expr)
for (var_name, var_expr) in self.val_bindings[::-1]
if var_name not in types_.BUILTINS_CONTEXT
)
# The behaviours a `Builtin` may wrap (a forward reference to the classes nested in `Builtin`).
BuiltinBehavior: TypeAlias = 'Builtin.BB_PLUS_CONST | Builtin.BB_PLUS'


@dataclass(frozen=True)
class Builtin:
    """A primitive function value whose meaning is supplied by its `behavior`."""
    behavior: BuiltinBehavior

    @dataclass(frozen=True)
    class BB_PLUS_CONST:
        """Behaviour of "add the constant `amt` to the argument"."""
        amt: int

        def name(self) -> str:
            """The explicitly signed constant, e.g. `+1` or `-1`."""
            return format(self.amt, '+')

        def js(self) -> str:
            """Generated source of a one-argument arrow function adding `amt`."""
            signed_amt = format(self.amt, '+')
            return '(x=>x' + signed_amt + ')'

        def run(self, e: Expression) -> Option[Expression]:
            """Add `amt` to an `Int` argument; yield None for any other argument."""
            if isinstance(e, Int):
                return Some(Int(e.value + self.amt))
            return None

    @dataclass(frozen=True)
    class BB_PLUS:
        """Behaviour of curried two-argument addition."""

        def name(self) -> str:
            return '+'

        def js(self) -> str:
            """Generated source of a curried addition function."""
            return '(x=>y=>x+y)'

        def run(self, e: Expression) -> Option[Expression]:
            """Partially apply to an `Int`, producing an add-constant builtin; None otherwise."""
            if isinstance(e, Int):
                return Some(Builtin(Builtin.BB_PLUS_CONST(e.value)))
            return None

    def subst(self, expression: Expression, variable: str) -> Expression:
        """Builtins contain no variables, so substitution leaves them unchanged."""
        return self

    def is_value(self) -> bool:
        return True

    def step(self) -> Option[Expression]:
        """Builtins never reduce on their own."""
        return None

    def try_apply(self, v: Expression) -> Option[Expression]:
        """Apply this builtin to `v` by delegating to the behaviour's `run`."""
        return self.behavior.run(v)

    def __repr__(self) -> str:
        # `repr(...)[1:-1]` keeps the escaped form of the name but drops its quotes.
        escaped_name = repr(self.behavior.name())[1:-1]
        return "'" + escaped_name + "'"

    def codegen(self) -> str:
        """Generated code for this builtin, as provided by the behaviour's `js`."""
        return self.behavior.js()

    # Factories for the two named builtins, called as `Builtin.PLUS()` / `Builtin.S()`.
    # NOTE(review): because these attributes are annotated, a standard-library
    # `@dataclass` collects them as fields with defaults, so `__match_args__` has
    # three entries.  Patterns such as `case Builtin(_, _)` elsewhere in this module
    # depend on that -- do not turn them into ClassVars without updating those.
    PLUS: 'Callable[[], Builtin]' = lambda: Builtin(Builtin.BB_PLUS())
    S: 'Callable[[], Builtin]' = lambda: Builtin(Builtin.BB_PLUS_CONST(1))
# Pairs each builtin's source-level name with the `Builtin` expression it denotes.
# These are the same two names that `Variable.step` resolves.
BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = (
('+', Builtin.PLUS()),
('S', Builtin.S()),
)
@dataclass(frozen=True)
class MonoFunc:
arg: str
body: Expression
def subst(self, expression: Expression, variable: str) -> Expression:
if variable == self.arg:
return self
else:
return MonoFunc(self.arg, self.body.subst(expression, variable))
def is_value(self) -> bool:
return True
def step(self) -> Option[Expression]:
return None
@staticmethod
def from_match_function(forms: 'Sequence[Tuple[Pattern, Expression]]') -> Result[Expression, MatchException]:
# In certain cases, starting a function with a full match tree may be unnecessary.
# Specifically, if there exists only one possible branch and that branch binds only
# one value and that value is equal to the whole entire input, rather than assigning
# that input to a new variable, we may simply use argument variable instead.
match forms:
case [(patt, body)]: # A single possible branch
match patt.bindings():
case []: # Binds nothing
return Ok(MonoFunc('_', body))
case [(var, [])]: # Binds a single variable to the entire input
return Ok(MonoFunc(var, body))
local_variable = mk_dollar()
# If those special cases fail, we eliminate the pattern matching to produce a
# single body:
match_trees = tuple( # Construct a match tree for each possible branch
pattern.match_tree(
EMPTY_STRUCT_PATH,
LeafNode.from_value(bindings_to_lets(pattern.bindings(), Variable(local_variable), body))
)
for (pattern, body) in forms
)
unified_match_tree = merge_all_trees(match_trees) # Unify all the trees
compiled_tree = compile_tree(unified_match_tree, Variable(local_variable)) # Turn each tree into IR
return compiled_tree <= p(MonoFunc, local_variable)
def try_apply(self, v: Expression) -> Option[Expression]:
return Some(self.body.subst(v, self.arg))
def codegen(self) -> str:
return f'({self.arg}=>{self.body.codegen()})'
def codegen_named(self, name) -> str:
return f'(function {name}({self.arg}){{return {self.body.codegen()}}})'
def __repr__(self) -> str:
return f'{{{repr(self.arg)}: {repr(self.body)}}}'
@dataclass
class LetBinding:
    """Binds `lhs` to `rhs` within `body`; `rhs` may refer to `lhs` recursively (see `step`)."""
    lhs: str
    rhs: Expression
    body: Expression

    def subst(self, expression: Expression, variable: str) -> Expression:
        """Substitute in both `rhs` and `body`, unless this binding shadows `variable`."""
        if self.lhs == variable:
            return self
        new_rhs = self.rhs.subst(expression, variable)
        new_body = self.body.subst(expression, variable)
        return LetBinding(self.lhs, new_rhs, new_body)

    def is_value(self) -> bool:
        return False

    def step(self) -> Option[Expression]:
        """Reduce `rhs` by one step, or -- once it is a value -- substitute it into `body`."""
        if not self.rhs.is_value():
            return map_opt(
                lambda rhs_step: LetBinding(self.lhs, rhs_step, self.body),
                self.rhs.step()
            )
        # Occurrences of `lhs` inside `rhs` are replaced by a copy of this binding
        # whose body is just the bound name, so that recursion can keep unrolling.
        self_reference = LetBinding(self.lhs, self.rhs, Variable(self.lhs))
        unrolled_rhs = self.rhs.subst(self_reference, self.lhs)
        return Some(self.body.subst(unrolled_rhs, self.lhs))

    def __repr__(self) -> str:
        return f'( "{self.lhs}", {self.rhs!r}, {self.body!r} )'

    def codegen(self) -> str:
        """Generated code: an immediately-applied arrow function binding `lhs`."""
        # Function literals get a named form so that their body can call them by name.
        if isinstance(self.rhs, MonoFunc):
            rhs_cg = self.rhs.codegen_named(self.lhs)
        else:
            rhs_cg = self.rhs.codegen()
        # `let x = e in x` is simply `e`.
        if self.body == Variable(self.lhs):
            return rhs_cg
        return f'({self.lhs}=>{self.body.codegen()})({rhs_cg})'
@dataclass
class Application:
first: Expression
arg: Expression
def subst(self, expression: Expression, variable: str) -> Expression:
return Application(
self.first.subst(expression, variable),
self.arg.subst(expression, variable)
)
def is_value(self) -> bool:
return False
def step(self) -> Option[Expression]:
match self.first.step():
case Some(first_stepped):
return Some(Application(first_stepped, self.arg))
case None:
match self.arg.step():
case Some(arg_stepped):
return Some(Application(self.first, arg_stepped))
case None:
assert isinstance(self.first, MonoFunc) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
return self.first.try_apply(self.arg)
raise Exception('Unreachable')
def __repr__(self) -> str:
return f'[ {repr(self.first)}, {repr(self.arg)} ]'
def codegen(self) -> str:
if isinstance(self.first, MonoFunc | Builtin) and self.arg.is_value():
return unwrap_opt(self.first.try_apply(self.arg)).codegen()
else:
match self.first:
case Application(Builtin(Builtin.BB_PLUS()), addend1):
return f'({addend1.codegen()} + {self.arg.codegen()})'
case Builtin(Builtin.BB_PLUS_CONST(n)):
return f'({self.arg.codegen()}{n:+})'
return f'({self.first.codegen()})({self.arg.codegen()})'
@dataclass
class Int:
    """An integer literal."""
    value: int

    def subst(self, expression: Expression, variable: str) -> Expression:
        """Literals contain no variables, so substitution leaves them unchanged."""
        return self

    def is_value(self) -> bool:
        return True

    def step(self) -> Option[Expression]:
        """Literals never reduce."""
        return None

    def __repr__(self) -> str:
        return f'{self.value}'

    def codegen(self) -> str:
        """Generated code: the decimal text of the literal."""
        return f'{self.value}'
@dataclass
class Variable:
    """A reference to the name `name`."""
    name: str

    def subst(self, expression: Expression, variable: str) -> Expression:
        """Return `expression` if this is the variable being replaced, otherwise `self`."""
        return expression if variable == self.name else self

    def is_value(self) -> bool:
        return False

    def step(self) -> Option[Expression]:
        """Resolve the builtin names `+` and `S`; every other variable cannot step."""
        if self.name == '+':
            return Some(Builtin.PLUS())
        if self.name == 'S':
            return Some(Builtin.S())
        return None

    def __repr__(self) -> str:
        # `repr(...)[1:-1]` keeps the escaped form of the name but drops its quotes.
        escaped_name = repr(self.name)[1:-1]
        return f'"{escaped_name}"'

    def codegen(self) -> str:
        """Generated code: the bare name."""
        return self.name
@dataclass
class Switch:
branches: Mapping[int, Expression]
fallback: Expression
switching_on: Expression
def subst(self, expression: Expression, variable: str) -> Expression:
return Switch(
{i: e.subst(expression, variable) for i, e in self.branches.items()},
self.fallback.subst(expression, variable),
self.switching_on.subst(expression, variable))
def is_value(self) -> bool:
return False
def step(self) -> Option[Expression]:
match self.switching_on.step():
case Some(switch_expr_stepped):
return Some(Switch(self.branches, self.fallback, switch_expr_stepped))
case None:
match self.switching_on:
case Int(n):
if n in self.branches:
return Some(self.branches[n])
else:
return Some(self.fallback)
raise Exception('Attempted to switch on non-integer value')
raise Exception('Unreachable')
def __repr__(self) -> str:
return '{ ' + ', '.join(f'{n}: ' + repr(e) for (n, e) in self.branches.items()) + f', _: {repr(self.fallback)}' + ' }'
def codegen(self) -> str:
switching_on_code = self.switching_on.codegen()
return ':'.join(
f'{switching_on_code}=={val}?({branch.codegen()})'
for val, branch in self.branches.items()
) + f':{self.fallback.codegen()}'
def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Result[Expression, MatchException]:
match tree:
case LeafNode([match]):
return Ok(match)
case LeafNode([]):
return Err(MatchException.Incomplete)
case LeafNode([a, b, *rest]):
return Err(MatchException.Ambiguous)
case IntNode(location, specific_trees, fallback_tree):
access_location = location_to_ir(location)(match_against)
match sequence(tuple(compile_tree(tree, match_against) for tree in specific_trees.values())):
case Err(e):
return Err(e)
case Ok(exprs):
match compile_tree(fallback_tree, match_against):
case Err(e):
return Err(e)
case Ok(fallback):
return Ok(Switch(dict(zip(specific_trees.keys(), exprs)), fallback, match_against))
raise Exception('Unreachable')
def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]:
def access_location(part: int) -> Callable[[Expression], Expression]:
def remove(expr: Expression) -> Expression:
return Application(Builtin(Builtin.BB_PLUS_CONST(-1)), expr)
def access_location_prime(expr: Expression) -> Expression:
if part < 1:
return remove(expr)
else:
raise AssertionError('A!')
return access_location_prime
match location:
case []:
return lambda o: o
case [part, *rest_location]:
return c(location_to_ir(StructurePath(rest_location)), access_location(part))
raise Exception('Unreachable')
def bindings_to_lets(bindings: Collection[Tuple[str, StructurePath]], deconstructing_term: Expression, body_expr: Expression) -> Expression:
match bindings:
case []:
return body_expr
case [(binding_name, location), *rest]:
return LetBinding(binding_name, location_to_ir(location)(deconstructing_term), bindings_to_lets(rest, deconstructing_term, body_expr))
raise Exception('Unreachable')
def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> Expression:
match bindings:
case []:
return body
case [(var, replacement), *rest]:
return subst_all(rest, body.subst(replacement, var))
raise Exception('Unreachable')
def count_uses(variable: str, expression: Expression) -> int:
match expression:
case MonoFunc(arg, body):
return 0 if arg == variable else count_uses(variable, body)
case Application(first, arg):
return count_uses(variable, first) + count_uses(variable, arg)
case Int(_):
return 0
case Variable(name):
return 1 if name == variable else 0
case Builtin(_, _):
return 0
case LetBinding(lhs, rhs, body):
return count_uses(variable, rhs) + count_uses(variable, body)
case ReplHole(_, _):
return 0
case Switch(branches, fallback, switching_on):
return (
count_uses(variable, switching_on) +
count_uses(variable, fallback) +
sum(count_uses(variable, branch) for branch in branches.values()))