From 666488ce0f8fecabf3582cb44f98188c98fa1b18 Mon Sep 17 00:00:00 2001 From: Emi Simpson Date: Fri, 15 Mar 2024 17:13:14 -0400 Subject: [PATCH] Fixed match tree merges creating ambiguous trees Now match tree merging favors the most specific match possible --- match_tree.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/match_tree.py b/match_tree.py index 3628e06..f97b58b 100644 --- a/match_tree.py +++ b/match_tree.py @@ -35,6 +35,9 @@ class LeafNode(Generic[A]): def __repr__(self) -> str: return '{..}' if len(self.matches) else 'X' + def fill_empty_leaves(self, with_values: Sequence[A]) -> MatchTree: + return self if len(self.matches) else LeafNode(with_values) + @dataclass(frozen=True) class IntNode(Generic[A]): location: StructurePath @@ -48,6 +51,13 @@ class IntNode(Generic[A]): merge_trees(self.fallback_tree, other) ) + def fill_empty_leaves(self, with_values: Sequence[A]) -> MatchTree: + return IntNode( + self.location, + {i: tree.fill_empty_leaves(with_values) for (i, tree) in self.specific_trees.items()}, + self.fallback_tree.fill_empty_leaves(with_values) + ) + def is_complete(self) -> bool: return all( subtree.is_complete() for subtree in self.specific_trees.values() @@ -58,6 +68,10 @@ class IntNode(Generic[A]): def merge_trees(t1: 'MatchTree[A]', t2: 'MatchTree[A]') -> 'MatchTree[A]': match (t1, t2): + case (LeafNode(matches1), LeafNode(matches2)): + return LeafNode((*matches1, *matches2)) + case (LeafNode(matches), other_node) | (other_node, LeafNode(matches)): + return other_node.fill_empty_leaves(matches) case (IntNode(location1, specific_trees1, fallback_tree1) as tree1, IntNode(location2, specific_trees2, fallback_tree2) as tree2): if location1 == location2: return IntNode( @@ -81,10 +95,6 @@ def merge_trees(t1: 'MatchTree[A]', t2: 'MatchTree[A]') -> 'MatchTree[A]': return tree1.merge_each_child(tree2) else: return tree2.merge_each_child(tree1) - case (IntNode() as int_node, match_node) | (match_node, IntNode() as int_node): - return int_node.merge_each_child(match_node) - case (LeafNode(matches1), LeafNode(matches2)): - return LeafNode((*matches1, *matches2)) raise Exception('Unreachable') def merge_all_trees(trees: 'Iterable[MatchTree[A]]') -> 'MatchTree[A]':