Implement compilation

This commit is contained in:
Emi Simpson 2024-03-15 09:34:54 -04:00
parent 673292e913
commit 88399fff69
Signed by: Emi
GPG key ID: A12F2C2FFDC3D847
8 changed files with 404 additions and 158 deletions

28
compile.py Normal file
View file

@ -0,0 +1,28 @@
from emis_funky_funktions import *
from typing import Collection, Sequence, TypeAlias
from ir import BUILTIN_SUBSTITUTIONS, Expression, ReplHole, subst_all
from genir import json_to_ir, PatternParseProblem, BranchTypesDiffer, UndefinedVariable
from types_ import BUILTINS_CONTEXT, UnificationError
from silly_thing import evaluate
import json
from dataclasses import dataclass
from operator import add
def main():
import sys
match sys.argv:
case [_, file]:
# TODO handle this
expr, ty, substs = unwrap_r(json_to_ir(json.loads(open(sys.argv[1]).read()), BUILTINS_CONTEXT))
result = evaluate(subst_all(BUILTIN_SUBSTITUTIONS, expr))
if isinstance(result, ReplHole):
print(result.render())
else:
print(result.codegen())
case _:
raise Exception('TODO')
if __name__ == '__main__':
main()

13
fibb.json Normal file
View file

@ -0,0 +1,13 @@
[
"fibb_helper",
{
"a": {"b": {
"0": "a",
"1": "b",
"S S n": ["fibb_helper", ["+", "a", "b"], ["+", ["+", "a", "b"], "b"], "n"]
}}
},
"fibb_helper",
0,
1
]

View file

@ -2,7 +2,8 @@ from emis_funky_funktions import *
from typing import Sequence, Mapping
from pattern import lex_and_parse_pattern
from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole, Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole
from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
from types_ import *
from functools import reduce

262
ir.py
View file

@ -1,137 +1,14 @@
from emis_funky_funktions import *
from typing import Collection, Mapping, Sequence, Tuple, TypeAlias
from functools import reduce
from match_tree import MatchException, StructurePath, LeafNode, merge_all_trees, IntNode
import types_
Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole'
Pattern: TypeAlias = 'NamePattern | IntPattern | SPattern | IgnorePattern'
@dataclass(frozen=True)
class NamePattern:
"""
A pattern which always succeeds to match, and binds a whole expression to a name
"""
name: str
def binds(self, var: str) -> bool:
"""
Test to see if this pattern binds a given variable
"""
return var == self.name
def match(self, e: Expression) -> Option[Sequence[Tuple[str, Expression]]]:
"""
Match an expression against this pattern
>>> NamePattern('my_var').match(Int(1))
Some((('my_var', 1),))
"""
return Some(((self.name, e),))
def __repr__(self) -> str:
return self.name
def codegen(self, match_on: str) -> Tuple[Sequence[str], Collection[Tuple[str, str]]]:
return (tuple(), [(self.name, match_on)])
@dataclass(frozen=True)
class IgnorePattern:
"""
A pattern which always succeeds to match, but binds nothing
"""
def binds(self, var: str) -> bool:
"""
Test to see if this pattern binds a given variable
For an `IgnorePattern` this is always false
"""
return False
def match(self, e: Expression) -> Option[Sequence[Tuple[str, Expression]]]:
"""
Match an expression against this pattern
>>> IgnorePattern().match(Int(1))
Some(())
"""
return Some(tuple())
def __repr__(self) -> str:
return '_'
def codegen(self, match_on: str) -> Tuple[Sequence[str], Collection[Tuple[str, str]]]:
return (tuple(), tuple())
@dataclass(frozen=True)
class IntPattern:
value: int
def binds(self, var: str) -> bool:
"""
Test to see if this pattern binds a given variable
For an `IntPattern` this is always false
"""
return False
def match(self, e: Expression) -> Option[Sequence[Tuple[str, Expression]]]:
"""
Match an expression against this pattern
>>> IntPattern(2).match(Int(1)) is None
True
>>> IntPattern(1).match(Int(1))
Some(())
"""
match e:
case Int(v) if v == self.value:
return Some(tuple())
return None
def __repr__(self) -> str:
return repr(self.value)
def codegen(self, match_on: str) -> Tuple[Sequence[str], Collection[Tuple[str, str]]]:
return ((f'{match_on}=={self.value}',), tuple())
@dataclass(frozen=True)
class SPattern:
pred: Pattern
def binds(self, var: str) -> bool:
"""
Test to see if this pattern binds a given variable
"""
return self.pred.binds(var)
def match(self, e: Expression) -> Option[Sequence[Tuple[str, Expression]]]:
"""
Match an expression against this pattern
>>> SPattern(NamePattern('n')).match(Int(1))
Some((('n', 0),))
>>> SPattern(NamePattern('n')).match(Int(0)) is None
True
>>> SPattern(SPattern(NamePattern('n'))).match(Int(4))
Some((('n', 2),))
"""
match e:
case Int(v) if v > 0:
return self.pred.match(Int(v - 1))
return None
def __repr__(self) -> str:
return 'S ' + repr(self.pred)
def codegen(self, match_on: str) -> Tuple[Sequence[str], Collection[Tuple[str, str]]]:
pred_conditions, pred_bindings = self.pred.codegen(f'({match_on}-1)')
return ((f'{match_on}>0', *pred_conditions), pred_bindings)
Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
Value: TypeAlias = 'Function | Int | Builtin | ReplHole'
@dataclass(frozen=True)
class ReplHole:
@ -213,7 +90,7 @@ BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = (
@dataclass(frozen=True)
class Function:
forms: Sequence[Tuple[Pattern, Expression]]
forms: 'Sequence[Tuple[patterns.Pattern, Expression]]'
def subst(self, expression: Expression, variable: str) -> Expression:
return Function([
@ -227,36 +104,20 @@ class Function:
def step(self) -> Option[Expression]:
return None
def function_to_tree(self) -> 'Result[Expression, MatchException]':
subtrees = (patt.match_tree([], LeafNode.from_value(expr)) for (patt, expr) in self.forms)
return reduce(merge_trees, subtrees, LeafNode(tuple()))
def eliminate(self, v: Expression) -> Result[Expression, MatchException]:
match_trees = tuple(pattern.match_tree([], LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body))) for (pattern, body) in self.forms)
unified_match_tree = merge_all_trees(match_trees)
return compile_tree(unified_match_tree, v)
def try_apply(self, v: Expression) -> Option[Expression]:
match tuple((bindings.val, body) for (pattern, body) in self.forms for bindings in (pattern.match(v),) if bindings is not None):
case []:
return None
case [(bindings, body), *rest]:
return Some(subst_all(bindings, body))
raise Exception('Unreachable')
return hush(self.eliminate(v))
def codegen_inner(self) -> str:
return (':'.join(
(
'&&'.join(
iter(pattern.codegen('$')[0])
) or '1'
) + '?' + (
'((' + ','.join(
binding_name
for (binding_name, binding_value)
in pattern.codegen('$')[1]
) + f')=>({branch.codegen()}))(' + ','.join(
binding_value
for (binding_name, binding_value)
in pattern.codegen('$')[1]
) + ')'
if len(pattern.codegen('$')[1]) else
branch.codegen()
)
for (pattern, branch) in self.forms
) or '0?[].b') + ':[].b'
return unwrap_r(self.eliminate(Variable('$'))).codegen()
def codegen(self) -> str:
return '$=>' + self.codegen_inner()
def codegen_named(self, name) -> str:
@ -387,6 +248,97 @@ class Variable:
def codegen(self) -> str:
return self.name
@dataclass
class Switch:
branches: Mapping[int, Expression]
fallback: Expression
switching_on: Expression
def subst(self, expression: Expression, variable: str) -> Expression:
return Switch(
{i: e.subst(expression, variable) for i, e in self.branches.items()},
self.switching_on.subst(expression, variable))
def is_value(self) -> bool:
return False
def step(self) -> Option[Expression]:
match self.switching_on.step():
case Some(switch_expr_stepped):
return Switch(self.branches, switch_expr_stepped)
case None:
match self.switching_on:
case Int(n):
if n in self.branches:
return Some(self.branches[n])
else:
return None
raise Exception('Attempted to switch on non-integer value')
raise Exception('Unreachable')
def __repr__(self) -> str:
return '{ ' + ', '.join(f'{n}: ' + repr(e) for (n, e) in self.branches.items()) + f', _: {repr(self.fallback)}' + ' }'
def codegen(self) -> str:
switching_on_code = self.switching_on.codegen()
return ':'.join(
f'{switching_on_code}==={val}?({branch.codegen()})'
for val, branch in self.branches.items()
) + f':{self.fallback.codegen()}'
def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Result[Expression, MatchException]:
def match_int_builtin(lookup: Callable[[int], Expression]) -> Callable[[Expression], Option[Expression]]:
def match_inner(i: Expression) -> Expression:
match i:
case Int(value):
return lookup(value)
raise Exception('Bad type! Eep!')
return match_inner
match tree:
case LeafNode([match]):
return Ok(match)
case LeafNode([]):
return Err(MatchException.Incomplete)
case LeafNode([a, b, *rest]):
return Err(MatchException.Ambiguous)
case IntNode(location, specific_trees, fallback_tree):
access_location = location_to_ir(location)(match_against)
match sequence(tuple(compile_tree(tree, match_against) for tree in specific_trees.values())):
case Err(e):
return Err(e)
case Ok(exprs):
match compile_tree(fallback_tree, match_against):
case Err(e):
return Err(e)
case Ok(fallback):
return Ok(Switch(dict(zip(specific_trees.keys(), exprs)), fallback, match_against))
def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]:
def access_location(part: int) -> Callable[[Expression], Expression]:
def remove(expr: Expression) -> Expression:
return Application(Builtin(f'{part}', Builtin._PLUS_CONST(-1), f'$=>$-1'), expr)
def access_location_prime(expr: Expression) -> Expression:
if part < 1:
return remove(expr)
else:
raise AssertionError('A!')
return access_location_prime
match location:
case []:
return lambda o: o
case [part, *rest_location]:
return c(location_to_ir(StructurePath(rest_location)), access_location(part))
raise Exception('Unreachable')
def bindings_to_lets(bindings: Sequence[Tuple[str, StructurePath]], deconstructing_term: Expression, body_expr: Expression) -> Expression:
match bindings:
case []:
return body_expr
case [(binding_name, location), *rest]:
return LetBinding(binding_name, location_to_ir(location)(deconstructing_term), bindings_to_lets(rest, deconstructing_term, body_expr))
def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> Expression:
match bindings:
case []:

91
match_tree.py Normal file
View file

@ -0,0 +1,91 @@
from emis_funky_funktions import *
from collections.abc import Collection, Mapping, Sequence
from dataclasses import dataclass
from enum import auto, IntEnum
from typing import Callable, Collection, NewType, TypeAlias
from functools import reduce
StructurePath = NewType('StructurePath', 'Sequence[int]')
Bindings = NewType('Bindings', 'Mapping[str, StructurePath]')
PatternID = NewType('PatternID', 'int')
MatchTree: TypeAlias = 'LeafNode[A] | IntNode[A]'
class MatchException(IntEnum):
Incomplete = auto()
Ambiguous = auto()
@dataclass(frozen=True)
class LeafNode(Generic[A]):
matches: Collection[A]
@staticmethod
def from_value(value: A) -> 'LeafNode[A]':
return LeafNode((value,))
@staticmethod
def fail() -> 'LeafNode[Any]':
return LeafNode([])
def is_complete(self) -> bool:
return len(matches) > 0
def __repr__(self) -> str:
return '{..}' if len(self.matches) else 'X'
@dataclass(frozen=True)
class IntNode(Generic[A]):
location: StructurePath
specific_trees: 'Mapping[int, MatchTree[A]]'
fallback_tree: 'MatchTree[A]'
def merge_each_child(self, other: 'MatchTree[A]'):
return IntNode(
self.location,
{i: merge_trees(tree, other) for (i, tree) in self.specific_trees.items()},
merge_trees(self.fallback_tree, other)
)
def is_complete(self) -> bool:
return all(
subtree.is_complete() for subtree in self.specific_trees.values()
) and self.fallback_tree.is_complete()
def __repr__(self) -> str:
return repr(self.location) + '{' + ','.join(f'{i}: {repr(mt)}' for i, mt in self.specific_trees.items()) + f', _: {self.fallback_tree}' + '}'
def merge_trees(t1: 'MatchTree[A]', t2: 'MatchTree[A]') -> 'MatchTree[A]':
match (t1, t2):
case (IntNode(location1, specific_trees1, fallback_tree1) as tree1, IntNode(location2, specific_trees2, fallback_tree2) as tree2):
if location1 == location2:
return IntNode(
location1,
{
i:
(
merge_trees(specific_trees1[i], specific_trees2[i])
if i in specific_trees1 and i in specific_trees2 else
merge_trees(specific_trees1[i], fallback_tree2)
if i in specific_trees1 and i not in specific_trees2 else
merge_trees(fallback_tree1, specific_trees2[i])
)
for i in (*specific_trees1.keys(), *specific_trees2.keys())
},
merge_trees(fallback_tree1, fallback_tree2),
)
elif list(location1) < list(location2):
return tree1.merge_each_child(tree2)
else:
return tree2.merge_each_child(tree1)
case (IntNode() as int_node, match_node) | (match_node, IntNode() as int_node):
return int_node.merge_each_child(match_node)
case (LeafNode(matches1), LeafNode(matches2)):
return LeafNode((*matches1, *matches2))
raise Exception('Unreachable')
def merge_all_trees(trees: 'Iterable[MatchTree[A]]') -> 'MatchTree[A]':
return reduce(merge_trees, trees, LeafNode.fail())

58
parse2.py Normal file
View file

@ -0,0 +1,58 @@
from emis_funky_funktions import *
from typing import AbstractSet, FrozenSet, TypeAlias, TypeGuard, TypeVar
Lexeme = TypeVar('Lexeme')
Token = TypeVar('Token')
Variable = TypeVar('Variable')
Handle: TypeAlias = Sequence[Variable | Token]
Production: TypeAlias = Tuple[Variable, Handle[Variable, Token]]
Grammar: TypeAlias = Sequence[Production[Variable, Token]]
NfaState: TypeAlias = Tuple[int, int]
Nfa: TypeAlias = Callable[[NfaState, Variable | Token], FrozenSet[NfaState]]
DfaState: TypeAlias = FrozenSet(Tuple[int, int])
Dfa: TypeAlias = Callable[[DfaState, Variable | Token], FrozenSet[NfaState]]
def build_nfa(
is_var: Callable[[Variable | Token], TypeGuard[Variable]],
grammar: Grammar[Variable, Token],
) -> Nfa[Variable, Token]:
def epsilon_closure_step(state: NfaState) -> FrozenSet[NfaState]:
production_no, symbol_no = state
_, production = grammar[production_no]
next_symbol = production[symbol_no]
if is_var(next_symbol):
possible_productions: Iterator[NfaState] = ((i, 0) for i, (variable, handle) in enumerate(grammar) if variable == next_symbol)
return fset(state, *possible_productions)
else:
return fset(state,)
def epsilon_closure(states: FrozenSet[NfaState], previous_states: FrozenSet[NfaState] = fset()) -> FrozenSet[NfaState]:
new_states = FSet(new_state for old_state in states for new_state in epsilon_closure_step(old_state)) - previous_states - states
if len(new_states) == 0:
return states | previous_states
else:
return epsilon_closure(new_states, states | previous_states)
def nfa(state: Tuple[int, int], symbol: Variable | Token) -> FrozenSet[NfaState]:
production_no, symbol_no = state
production = grammar[production_no]
next_symbol = production[symbol_no]
if next_symbol == symbol:
return epsilon_closure(fset((production_no, symbol_no + 1)))
else:
return fset()
def dfa(dstate: DfaState, symbol: Variable | Token) -> DfaState:
return FSet(
new_nstate
for nstate in dstate
for new_nstate in nfa(nstate, symbol)
)
return nfa

View file

@ -2,7 +2,7 @@ from emis_funky_funktions import *
from typing import Collection, Mapping, Sequence, Tuple, TypeAlias
from comb_parse import Parser
from ir import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
from lex import Lexeme, tokenize
from enum import auto, IntEnum
@ -53,4 +53,4 @@ def lex_and_parse_pattern(input: str) -> Result[Pattern, str | Mapping[Lexeme[Pa
return Err(e)
case Err(remainder):
return Err(remainder)
raise Exception('Unreachable')
raise Exception('Unreachable')

103
patterns.py Normal file
View file

@ -0,0 +1,103 @@
from emis_funky_funktions import *
from typing import Collection, Mapping, Sequence, Tuple, TypeAlias
import types_
from match_tree import LeafNode, IntNode, MatchTree, StructurePath
Pattern: TypeAlias = 'NamePattern | IntPattern | SPattern | IgnorePattern'
@dataclass(frozen=True)
class NamePattern:
"""
A pattern which always succeeds to match, and binds a whole expression to a name
"""
name: str
def match_tree(self, location: StructurePath, success_leaf: 'MatchTree[A]') -> 'MatchTree[A]':
return success_leaf
def bindings(self) -> Collection[Tuple[str, StructurePath]]:
return ((self.name, tuple()),)
def binds(self, var: str) -> bool:
"""
Test to see if this pattern binds a given variable
"""
return var == self.name
def __repr__(self) -> str:
return self.name
@dataclass(frozen=True)
class IgnorePattern:
"""
A pattern which always succeeds to match, but binds nothing
"""
def match_tree(self, location: StructurePath, success_leaf: 'MatchTree[A]') -> 'MatchTree[A]':
return success_leaf
def bindings(self) -> Collection[Tuple[str, StructurePath]]:
return tuple()
def binds(self, var: str) -> bool:
"""
Test to see if this pattern binds a given variable
For an `IgnorePattern` this is always false
"""
return False
def __repr__(self) -> str:
return '_'
@dataclass(frozen=True)
class IntPattern:
value: int
def match_tree(self, location: StructurePath, success_leaf: 'MatchTree[A]') -> 'MatchTree[A]':
return IntNode(location, {self.value: success_leaf}, LeafNode([]))
def bindings(self) -> Collection[Tuple[str, StructurePath]]:
return tuple()
def binds(self, var: str) -> bool:
"""
Test to see if this pattern binds a given variable
For an `IntPattern` this is always false
"""
return False
def __repr__(self) -> str:
return repr(self.value)
@dataclass(frozen=True)
class SPattern:
pred: Pattern
def match_tree(self, location: StructurePath, success_leaf: 'MatchTree[A]') -> 'MatchTree[A]':
match self.pred.match_tree(StructurePath((*location, 0)), success_leaf):
case IntNode(child_location, trees, fallback) as child:
# If the child check is also an int node on pred of our current location:
# "raise" that node to be at our current level.
if child_location == (*location, 0):
return IntNode(location, dict(((0, LeafNode.fail()), *((1 + v, mt) for v, mt in trees.items()))), fallback)
else:
return IntNode(location, {0: LeafNode.fail()}, child)
case child:
return IntNode(location, {0: LeafNode.fail()}, child)
def bindings(self) -> Collection[Tuple[str, StructurePath]]:
# Prepend each binding path of the child pattern with 0 (i.e. subtract one from an int)
return tuple((name, (0, *path)) for (name, path) in self.pred.bindings())
def binds(self, var: str) -> bool:
"""
Test to see if this pattern binds a given variable
"""
return self.pred.binds(var)
def __repr__(self) -> str:
return 'S ' + repr(self.pred)