from emis_funky_funktions import * from collections.abc import Collection, Mapping, Sequence from dataclasses import dataclass from enum import auto, IntEnum from typing import Callable, Collection, NewType, TypeAlias from functools import reduce StructurePath = NewType('StructurePath', 'Sequence[int]') Bindings = NewType('Bindings', 'Mapping[str, StructurePath]') PatternID = NewType('PatternID', 'int') MatchTree: TypeAlias = 'LeafNode[A] | IntNode[A]' class MatchException(IntEnum): Incomplete = auto() Ambiguous = auto() @dataclass(frozen=True) class LeafNode(Generic[A]): matches: Collection[A] @staticmethod def from_value(value: A) -> 'LeafNode[A]': return LeafNode((value,)) @staticmethod def fail() -> 'LeafNode[Any]': return LeafNode([]) def is_complete(self) -> bool: return len(matches) > 0 def __repr__(self) -> str: return '{..}' if len(self.matches) else 'X' def fill_empty_leaves(self, with_values: Sequence[A]) -> MatchTree: return self if len(self.matches) else LeafNode(with_values) @dataclass(frozen=True) class IntNode(Generic[A]): location: StructurePath specific_trees: 'Mapping[int, MatchTree[A]]' fallback_tree: 'MatchTree[A]' def merge_each_child(self, other: 'MatchTree[A]'): return IntNode( self.location, {i: merge_trees(tree, other) for (i, tree) in self.specific_trees.items()}, merge_trees(self.fallback_tree, other) ) def fill_empty_leaves(self, with_values: Sequence[A]) -> MatchTree: return IntNode( self.location, {i: tree.fill_empty_leaves(with_values) for (i, tree) in self.specific_trees.items()}, self.fallback_tree.fill_empty_leaves(with_values) ) def is_complete(self) -> bool: return all( subtree.is_complete() for subtree in self.specific_trees.values() ) and self.fallback_tree.is_complete() def __repr__(self) -> str: return repr(self.location) + '{' + ','.join(f'{i}: {repr(mt)}' for i, mt in self.specific_trees.items()) + f', _: {self.fallback_tree}' + '}' def merge_trees(t1: 'MatchTree[A]', t2: 'MatchTree[A]') -> 'MatchTree[A]': match (t1, t2): case (LeafNode(matches1), LeafNode(matches2)): return LeafNode((*matches1, *matches2)) case (LeafNode(matches), other_node) | (other_node, LeafNode(matches)): return other_node.fill_empty_leaves(matches) case (IntNode(location1, specific_trees1, fallback_tree1) as tree1, IntNode(location2, specific_trees2, fallback_tree2) as tree2): if location1 == location2: return IntNode( location1, { i: ( merge_trees(specific_trees1[i], specific_trees2[i]) if i in specific_trees1 and i in specific_trees2 else merge_trees(specific_trees1[i], fallback_tree2) if i in specific_trees1 and i not in specific_trees2 else merge_trees(fallback_tree1, specific_trees2[i]) ) for i in (*specific_trees1.keys(), *specific_trees2.keys()) }, merge_trees(fallback_tree1, fallback_tree2), ) elif list(location1) < list(location2): return tree1.merge_each_child(tree2) else: return tree2.merge_each_child(tree1) raise Exception('Unreachable') def merge_all_trees(trees: 'Iterable[MatchTree[A]]') -> 'MatchTree[A]': return reduce(merge_trees, trees, LeafNode.fail())