JSON-Lang/match_tree.py

91 lines
2.8 KiB
Python
Raw Normal View History

2024-03-15 13:34:54 +00:00
from emis_funky_funktions import *
from collections.abc import Collection, Mapping, Sequence
from dataclasses import dataclass
from enum import auto, IntEnum
from typing import Callable, Collection, NewType, TypeAlias
from functools import reduce
StructurePath = NewType('StructurePath', 'Sequence[int]')
Bindings = NewType('Bindings', 'Mapping[str, StructurePath]')
PatternID = NewType('PatternID', 'int')
MatchTree: TypeAlias = 'LeafNode[A] | IntNode[A]'
class MatchException(IntEnum):
    """Integer codes for the two kinds of match problems this module names.

    The values are spelled out explicitly; they equal what ``auto()`` assigns
    for an ``IntEnum`` (counting up from 1).
    """
    Incomplete = 1
    Ambiguous = 2
@dataclass(frozen=True)
class LeafNode(Generic[A]):
    """A terminal node of a match tree holding the values that match there.

    A leaf with an empty ``matches`` collection represents failure (nothing
    matches); see ``fail()``.
    """
    matches: Collection[A]  # everything matching at this leaf; empty for a failure leaf

    @staticmethod
    def from_value(value: A) -> 'LeafNode[A]':
        """Return a leaf whose only match is ``value``."""
        return LeafNode((value,))

    @staticmethod
    def fail() -> 'LeafNode[Any]':
        """Return a leaf with no matches."""
        # Use a tuple, like from_value() and merge_trees(), so that a merged
        # empty leaf compares equal to fail() and the frozen dataclass stays
        # hashable (a list field would make hash() raise TypeError).
        return LeafNode(())

    def is_complete(self) -> bool:
        """Return True when at least one value matches at this leaf."""
        # Must read the field through ``self``; the bare name is undefined.
        return len(self.matches) > 0

    def __repr__(self) -> str:
        """Return ``{..}`` for a leaf with matches, ``X`` for a failure leaf."""
        return '{..}' if len(self.matches) else 'X'
@dataclass(frozen=True)
class IntNode(Generic[A]):
    """An inner match-tree node that branches on an integer.

    ``specific_trees`` maps particular integer values to the subtree used for
    them, and ``fallback_tree`` covers every other value.  The integer is
    presumably the one found at ``location`` in the value being matched --
    confirm against the code that walks the tree.
    """
    location: StructurePath
    specific_trees: 'Mapping[int, MatchTree[A]]'
    fallback_tree: 'MatchTree[A]'

    def merge_each_child(self, other: 'MatchTree[A]'):
        """Return a new node at the same location with ``other`` merged into
        each specific subtree and into the fallback (via ``merge_trees``)."""
        merged_children = {}
        for key, child in self.specific_trees.items():
            merged_children[key] = merge_trees(child, other)
        merged_fallback = merge_trees(self.fallback_tree, other)
        return IntNode(self.location, merged_children, merged_fallback)

    def is_complete(self) -> bool:
        """Return True only if every specific subtree and the fallback are complete."""
        # Specific subtrees are checked first, stopping at the first failure,
        # and only then the fallback.
        for child in self.specific_trees.values():
            if not child.is_complete():
                return False
        return self.fallback_tree.is_complete()

    def __repr__(self) -> str:
        """Render as ``location{k: subtree,..., _: fallback}``."""
        branches = ','.join(f'{key}: {child!r}' for key, child in self.specific_trees.items())
        body = branches + f', _: {self.fallback_tree}'
        return f'{self.location!r}{{{body}}}'
def merge_trees(t1: 'MatchTree[A]', t2: 'MatchTree[A]') -> 'MatchTree[A]':
match (t1, t2):
case (IntNode(location1, specific_trees1, fallback_tree1) as tree1, IntNode(location2, specific_trees2, fallback_tree2) as tree2):
if location1 == location2:
return IntNode(
location1,
{
i:
(
merge_trees(specific_trees1[i], specific_trees2[i])
if i in specific_trees1 and i in specific_trees2 else
merge_trees(specific_trees1[i], fallback_tree2)
if i in specific_trees1 and i not in specific_trees2 else
merge_trees(fallback_tree1, specific_trees2[i])
)
for i in (*specific_trees1.keys(), *specific_trees2.keys())
},
merge_trees(fallback_tree1, fallback_tree2),
)
elif list(location1) < list(location2):
return tree1.merge_each_child(tree2)
else:
return tree2.merge_each_child(tree1)
case (IntNode() as int_node, match_node) | (match_node, IntNode() as int_node):
return int_node.merge_each_child(match_node)
case (LeafNode(matches1), LeafNode(matches2)):
return LeafNode((*matches1, *matches2))
raise Exception('Unreachable')
def merge_all_trees(trees: 'Iterable[MatchTree[A]]') -> 'MatchTree[A]':
    """Fold ``merge_trees`` over ``trees`` from left to right.

    The fold starts from ``LeafNode.fail()``, so an empty input yields a
    leaf with no matches.
    """
    combined = LeafNode.fail()
    for tree in trees:
        combined = merge_trees(combined, tree)
    return combined