JSON-Lang/match_tree.py

103 lines
3.4 KiB
Python

from emis_funky_funktions import *
from collections.abc import Collection, Mapping, Sequence
from dataclasses import dataclass
from enum import auto, IntEnum
from typing import Callable, Collection, NewType, TypeAlias
from functools import reduce
StructurePath = NewType('StructurePath', 'Sequence[int]')
Bindings = NewType('Bindings', 'Mapping[str, StructurePath]')
PatternID = NewType('PatternID', 'int')
MatchTree: TypeAlias = 'LeafNode[A] | IntNode[A]'
EMPTY_STRUCT_PATH: StructurePath = StructurePath(tuple())
class MatchException(IntEnum):
    """Integer codes named ``Incomplete`` and ``Ambiguous``.

    NOTE(review): presumably the kinds of failure reported for a match tree
    (no raising site is visible in this chunk) — confirm against callers.
    """
    # Values written out explicitly; they equal what auto() assigns (1, 2).
    Incomplete = 1
    Ambiguous = 2
@dataclass(frozen=True)
class LeafNode(Generic[A]):
    """Terminal node of a match tree holding the values that match there.

    A leaf whose ``matches`` collection is empty is a failing leaf: it is
    reported as incomplete and rendered as ``X``.
    """
    matches: Collection[A]  # values matching at this leaf; empty means "no match"

    @staticmethod
    def from_value(value: A) -> 'LeafNode[A]':
        """Build a leaf whose only match is ``value``."""
        return LeafNode((value,))

    def is_complete(self) -> bool:
        """Return True when this leaf holds at least one match."""
        return len(self.matches) != 0

    def __repr__(self) -> str:
        # Failing leaves print as 'X'; any non-empty leaf prints as '{..}'.
        if len(self.matches) == 0:
            return 'X'
        return '{..}'

    def fill_empty_leaves(self, with_values: Collection[A]) -> MatchTree:
        """Replace this leaf with ``LeafNode(with_values)`` if it is empty.

        A leaf that already has matches is returned unchanged.
        """
        if len(self.matches) == 0:
            return LeafNode(with_values)
        return self
# Shared leaf with no matches (a failing leaf). merge_all_trees uses it as the
# starting value of its fold.
FAIL_NODE: LeafNode[Any] = LeafNode(())
@dataclass(frozen=True)
class IntNode(Generic[A]):
    """Interior match-tree node that branches on an integer.

    ``specific_trees`` maps particular integer values to the subtree to follow,
    and ``fallback_tree`` covers every value without an entry.

    NOTE(review): presumably the integer examined is the one found at
    ``location`` in the value being matched — confirm against the tree walker.
    """
    location: StructurePath  # where in the structure this node branches
    specific_trees: 'Mapping[int, MatchTree[A]]'  # subtree per specific integer value
    fallback_tree: 'MatchTree[A]'  # subtree for values with no specific entry

    def merge_each_child(self, other: 'MatchTree[A]') -> 'IntNode[A]':
        """Return a copy of this node with ``other`` merged into every child.

        Each specific subtree and the fallback subtree are combined with
        ``other`` via ``merge_trees``; ``location`` is kept as-is.
        """
        return IntNode(
            self.location,
            {i: merge_trees(tree, other) for (i, tree) in self.specific_trees.items()},
            merge_trees(self.fallback_tree, other),
        )

    def fill_empty_leaves(self, with_values: Collection[A]) -> MatchTree:
        """Return a copy in which every empty leaf is filled with ``with_values``.

        The call is delegated to each child (specific and fallback).
        """
        return IntNode(
            self.location,
            {i: tree.fill_empty_leaves(with_values) for (i, tree) in self.specific_trees.items()},
            self.fallback_tree.fill_empty_leaves(with_values),
        )

    def is_complete(self) -> bool:
        """Return True when every child (specific ones and fallback) is complete."""
        return all(
            subtree.is_complete() for subtree in self.specific_trees.values()
        ) and self.fallback_tree.is_complete()

    def __repr__(self) -> str:
        branches = ','.join(f'{i}: {repr(mt)}' for i, mt in self.specific_trees.items())
        # Only separate the fallback entry when there are specific branches in
        # front of it; previously an empty map rendered as "{, _: ...}".
        separator = ', ' if branches else ''
        return repr(self.location) + '{' + branches + separator + f'_: {self.fallback_tree}' + '}'
def merge_trees(t1: 'MatchTree[A]', t2: 'MatchTree[A]') -> 'MatchTree[A]':
match (t1, t2):
case (LeafNode(matches1), LeafNode(matches2)):
return LeafNode((*matches1, *matches2))
case (LeafNode(matches), other_node) | (other_node, LeafNode(matches)): #type: ignore
# For ignore, see: https://github.com/python/mypy/issues/13950
return other_node.fill_empty_leaves(matches)
case (IntNode(location1, specific_trees1, fallback_tree1) as tree1, IntNode(location2, specific_trees2, fallback_tree2) as tree2):
if location1 == location2:
return IntNode(
location1,
{
i:
(
merge_trees(specific_trees1[i], specific_trees2[i])
if i in specific_trees1 and i in specific_trees2 else
merge_trees(specific_trees1[i], fallback_tree2)
if i in specific_trees1 and i not in specific_trees2 else
merge_trees(fallback_tree1, specific_trees2[i])
)
for i in (*specific_trees1.keys(), *specific_trees2.keys())
},
merge_trees(fallback_tree1, fallback_tree2),
)
elif list(location1) < list(location2):
return tree1.merge_each_child(tree2)
else:
return tree2.merge_each_child(tree1)
raise Exception('Unreachable')
def merge_all_trees(trees: 'Iterable[MatchTree[A]]') -> 'MatchTree[A]':
    """Fold ``trees`` together left to right with ``merge_trees``.

    The fold starts from ``FAIL_NODE``, so an empty iterable yields ``FAIL_NODE``.
    """
    # NOTE(review): `Iterable` is not among the visible explicit imports; it may
    # come from the star import — confirm if these annotations are ever resolved.
    merged: MatchTree[A] = FAIL_NODE  # annotated for type-checking purposes
    for tree in trees:
        merged = merge_trees(merged, tree)
    return merged