91 lines
2.8 KiB
Python
91 lines
2.8 KiB
Python
|
from emis_funky_funktions import *
|
||
|
|
||
|
from collections.abc import Collection, Mapping, Sequence
|
||
|
from dataclasses import dataclass
|
||
|
from enum import auto, IntEnum
|
||
|
from typing import Callable, Collection, NewType, TypeAlias
|
||
|
from functools import reduce
|
||
|
|
||
|
StructurePath = NewType('StructurePath', 'Sequence[int]')
|
||
|
Bindings = NewType('Bindings', 'Mapping[str, StructurePath]')
|
||
|
PatternID = NewType('PatternID', 'int')
|
||
|
|
||
|
|
||
|
MatchTree: TypeAlias = 'LeafNode[A] | IntNode[A]'
|
||
|
|
||
|
class MatchException(IntEnum):
    """Categories of problems that can be reported about a match.

    Note: despite its name this is an ``IntEnum``, not an ``Exception``
    subclass, so its members cannot be raised.
    """

    # Values spelled out explicitly; they equal what ``auto()`` assigns.
    Incomplete = 1
    Ambiguous = 2
|
||
|
|
||
|
@dataclass(frozen=True)
class LeafNode(Generic[A]):
    """A leaf of a match tree, holding every value that matches at this point.

    An empty ``matches`` collection represents a failing leaf (nothing
    matches); such a leaf is not complete.
    """

    # All values matching at this leaf; empty means the match fails here.
    matches: Collection[A]

    @staticmethod
    def from_value(value: A) -> 'LeafNode[A]':
        """Return a leaf whose only match is ``value``."""
        return LeafNode((value,))

    @staticmethod
    def fail() -> 'LeafNode[Any]':
        """Return a leaf with no matches."""
        # Use a tuple, like from_value and merge_trees do: a list would make
        # hash() of this frozen dataclass raise TypeError, and would never
        # compare equal to an empty tuple-backed leaf.
        return LeafNode(())

    def is_complete(self) -> bool:
        """Return True if at least one value matches at this leaf."""
        return len(self.matches) > 0

    def __repr__(self) -> str:
        # '{..}' for a leaf with matches, 'X' for a failing leaf.
        return '{..}' if self.is_complete() else 'X'
|
||
|
|
||
|
@dataclass(frozen=True)
class IntNode(Generic[A]):
    """Interior match-tree node that branches on an integer.

    ``location`` identifies the integer being tested. ``specific_trees``
    gives the subtree for each integer value handled explicitly, and
    ``fallback_tree`` is the subtree for every value not listed there.
    """

    # Path to the integer this node branches on (presumably within the
    # structure being matched — confirm against the matcher).
    location: StructurePath
    # Subtree to continue with for each explicitly handled integer value.
    specific_trees: 'Mapping[int, MatchTree[A]]'
    # Subtree to continue with for any value absent from specific_trees.
    fallback_tree: 'MatchTree[A]'

    def merge_each_child(self, other: 'MatchTree[A]') -> 'IntNode[A]':
        """Return a copy of this node with ``other`` merged into every branch.

        The location is kept. Each specific subtree and the fallback
        subtree is replaced by ``merge_trees(subtree, other)``.
        """
        return IntNode(
            self.location,
            {i: merge_trees(tree, other) for i, tree in self.specific_trees.items()},
            merge_trees(self.fallback_tree, other),
        )

    def is_complete(self) -> bool:
        """Return True if every branch, including the fallback, is complete."""
        return all(
            subtree.is_complete() for subtree in self.specific_trees.values()
        ) and self.fallback_tree.is_complete()

    def __repr__(self) -> str:
        # Rendered as ``location{value: subtree, ..., _: fallback}``. The
        # fallback entry is always present, so all entries are joined
        # uniformly; this avoids a stray leading ", " when there are no
        # specific branches.
        entries = [f'{i}: {subtree!r}' for i, subtree in self.specific_trees.items()]
        entries.append(f'_: {self.fallback_tree!r}')
        return repr(self.location) + '{' + ', '.join(entries) + '}'
|
||
|
|
||
|
def merge_trees(t1: 'MatchTree[A]', t2: 'MatchTree[A]') -> 'MatchTree[A]':
|
||
|
match (t1, t2):
|
||
|
case (IntNode(location1, specific_trees1, fallback_tree1) as tree1, IntNode(location2, specific_trees2, fallback_tree2) as tree2):
|
||
|
if location1 == location2:
|
||
|
return IntNode(
|
||
|
location1,
|
||
|
{
|
||
|
i:
|
||
|
(
|
||
|
merge_trees(specific_trees1[i], specific_trees2[i])
|
||
|
if i in specific_trees1 and i in specific_trees2 else
|
||
|
|
||
|
merge_trees(specific_trees1[i], fallback_tree2)
|
||
|
if i in specific_trees1 and i not in specific_trees2 else
|
||
|
|
||
|
merge_trees(fallback_tree1, specific_trees2[i])
|
||
|
)
|
||
|
for i in (*specific_trees1.keys(), *specific_trees2.keys())
|
||
|
},
|
||
|
merge_trees(fallback_tree1, fallback_tree2),
|
||
|
)
|
||
|
elif list(location1) < list(location2):
|
||
|
return tree1.merge_each_child(tree2)
|
||
|
else:
|
||
|
return tree2.merge_each_child(tree1)
|
||
|
case (IntNode() as int_node, match_node) | (match_node, IntNode() as int_node):
|
||
|
return int_node.merge_each_child(match_node)
|
||
|
case (LeafNode(matches1), LeafNode(matches2)):
|
||
|
return LeafNode((*matches1, *matches2))
|
||
|
raise Exception('Unreachable')
|
||
|
|
||
|
def merge_all_trees(trees: 'Iterable[MatchTree[A]]') -> 'MatchTree[A]':
    """Merge any number of match trees into one.

    Trees are folded left to right with ``merge_trees``, starting from a
    failing leaf. An empty iterable therefore yields ``LeafNode.fail()``.
    """
    combined: 'MatchTree[A]' = LeafNode.fail()
    for tree in trees:
        combined = merge_trees(combined, tree)
    return combined
|