From f04553ac4d1fdb6f3fd19c4443f7350b945cfe31 Mon Sep 17 00:00:00 2001 From: Emi Simpson Date: Fri, 15 Mar 2024 19:20:19 -0400 Subject: [PATCH] Clean up some type errors And in doing so, fix some real errors and remove some unused code --- emis_funky_funktions.py | 2 +- ir.py | 27 +++++++++------------------ match_tree.py | 20 +++++++++++--------- patterns.py | 12 ++++++------ 4 files changed, 27 insertions(+), 34 deletions(-) diff --git a/emis_funky_funktions.py b/emis_funky_funktions.py index bc30b8b..598d3f9 100644 --- a/emis_funky_funktions.py +++ b/emis_funky_funktions.py @@ -829,7 +829,7 @@ def sequence(s: Sequence[Result[A, B]]) -> Result[Sequence[A], B]: Err('Oops!') """ if all(s): - return Ok(list(map(unwrap_r, s))) + return Ok(tuple(map(unwrap_r, s))) else: o = next(filter(not_, s)) assert isinstance(o, Err) diff --git a/ir.py b/ir.py index b2b590f..141d8b5 100644 --- a/ir.py +++ b/ir.py @@ -2,7 +2,8 @@ from emis_funky_funktions import * from typing import Collection, Mapping, Sequence, Tuple, TypeAlias from functools import reduce -from match_tree import MatchException, StructurePath, LeafNode, merge_all_trees, IntNode +from match_tree import MatchTree, MatchException, StructurePath, LeafNode, merge_all_trees, IntNode, EMPTY_STRUCT_PATH, FAIL_NODE +from patterns import Pattern import types_ @@ -91,7 +92,7 @@ BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = ( @dataclass(frozen=True) class Function: - forms: 'Sequence[Tuple[patterns.Pattern, Expression]]' + forms: 'Sequence[Tuple[Pattern, Expression]]' def subst(self, expression: Expression, variable: str) -> Expression: return Function([ @@ -105,12 +106,8 @@ class Function: def step(self) -> Option[Expression]: return None - def function_to_tree(self) -> 'Result[Expression, MatchException]': - subtrees = (patt.match_tree([], LeafNode.from_value(expr)) for (patt, expr) in self.forms) - return reduce(merge_trees, subtrees, LeafNode(tuple())) - def eliminate(self, v: Expression) -> Result[Expression, MatchException]: - match_trees = tuple(pattern.match_tree([], LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body))) for (pattern, body) in self.forms) + match_trees = tuple(pattern.match_tree(EMPTY_STRUCT_PATH, LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body))) for (pattern, body) in self.forms) unified_match_tree = merge_all_trees(match_trees) return compile_tree(unified_match_tree, v) @@ -265,6 +262,7 @@ class Switch: def subst(self, expression: Expression, variable: str) -> Expression: return Switch( {i: e.subst(expression, variable) for i, e in self.branches.items()}, + self.fallback, self.switching_on.subst(expression, variable)) def is_value(self) -> bool: @@ -273,7 +271,7 @@ class Switch: def step(self) -> Option[Expression]: match self.switching_on.step(): case Some(switch_expr_stepped): - return Switch(self.branches, switch_expr_stepped) + return Some(Switch(self.branches, self.fallback, switch_expr_stepped)) case None: match self.switching_on: case Int(n): @@ -295,15 +293,6 @@ class Switch: ) + f':{self.fallback.codegen()}' def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Result[Expression, MatchException]: - - def match_int_builtin(lookup: Callable[[int], Expression]) -> Callable[[Expression], Option[Expression]]: - def match_inner(i: Expression) -> Expression: - match i: - case Int(value): - return lookup(value) - raise Exception('Bad type! Eep!') - return match_inner - match tree: case LeafNode([match]): return Ok(match) @@ -322,6 +311,7 @@ def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Re return Err(e) case Ok(fallback): return Ok(Switch(dict(zip(specific_trees.keys(), exprs)), fallback, match_against)) + raise Exception('Unreachable') def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]: def access_location(part: int) -> Callable[[Expression], Expression]: @@ -340,12 +330,13 @@ def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression return c(location_to_ir(StructurePath(rest_location)), access_location(part)) raise Exception('Unreachable') -def bindings_to_lets(bindings: Sequence[Tuple[str, StructurePath]], deconstructing_term: Expression, body_expr: Expression) -> Expression: +def bindings_to_lets(bindings: Collection[Tuple[str, StructurePath]], deconstructing_term: Expression, body_expr: Expression) -> Expression: match bindings: case []: return body_expr case [(binding_name, location), *rest]: return LetBinding(binding_name, location_to_ir(location)(deconstructing_term), bindings_to_lets(rest, deconstructing_term, body_expr)) + raise Exception('Unreachable') def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> Expression: match bindings: diff --git a/match_tree.py b/match_tree.py index f97b58b..1a891a1 100644 --- a/match_tree.py +++ b/match_tree.py @@ -13,6 +13,8 @@ PatternID = NewType('PatternID', 'int') MatchTree: TypeAlias = 'LeafNode[A] | IntNode[A]' +EMPTY_STRUCT_PATH: StructurePath = StructurePath(tuple()) + class MatchException(IntEnum): Incomplete = auto() Ambiguous = auto() @@ -25,19 +27,17 @@ class LeafNode(Generic[A]): def from_value(value: A) -> 'LeafNode[A]': return LeafNode((value,)) - @staticmethod - def fail() -> 'LeafNode[Any]': - return LeafNode([]) - def is_complete(self) -> bool: - return len(matches) > 0 + return len(self.matches) > 0 def __repr__(self) -> str: return '{..}' if len(self.matches) else 'X' - def fill_empty_leaves(self, with_values: Sequence[A]) -> MatchTree: + def fill_empty_leaves(self, with_values: Collection[A]) -> MatchTree: return self if len(self.matches) else LeafNode(with_values) +FAIL_NODE: LeafNode[Any] = LeafNode(tuple()) + @dataclass(frozen=True) class IntNode(Generic[A]): location: StructurePath @@ -51,7 +51,7 @@ class IntNode(Generic[A]): merge_trees(self.fallback_tree, other) ) - def fill_empty_leaves(self, with_values: Sequence[A]) -> MatchTree: + def fill_empty_leaves(self, with_values: Collection[A]) -> MatchTree: return IntNode( self.location, {i: tree.fill_empty_leaves(with_values) for (i, tree) in self.specific_trees.items()}, @@ -70,7 +70,8 @@ def merge_trees(t1: 'MatchTree[A]', t2: 'MatchTree[A]') -> 'MatchTree[A]': match (t1, t2): case (LeafNode(matches1), LeafNode(matches2)): return LeafNode((*matches1, *matches2)) - case (LeafNode(matches), other_node) | (other_node, LeafNode(matches)): + case (LeafNode(matches), other_node) | (other_node, LeafNode(matches)): #type: ignore + # For ignore, see: https://github.com/python/mypy/issues/13950 return other_node.fill_empty_leaves(matches) case (IntNode(location1, specific_trees1, fallback_tree1) as tree1, IntNode(location2, specific_trees2, fallback_tree2) as tree2): if location1 == location2: @@ -98,4 +99,5 @@ def merge_trees(t1: 'MatchTree[A]', t2: 'MatchTree[A]') -> 'MatchTree[A]': raise Exception('Unreachable') def merge_all_trees(trees: 'Iterable[MatchTree[A]]') -> 'MatchTree[A]': - return reduce(merge_trees, trees, LeafNode.fail()) \ No newline at end of file + fail_node: MatchTree[A] = FAIL_NODE # Annotation or type checking purposes + return reduce(merge_trees, trees, fail_node) \ No newline at end of file diff --git a/patterns.py b/patterns.py index 20e2007..5dfc271 100644 --- a/patterns.py +++ b/patterns.py @@ -3,7 +3,7 @@ from emis_funky_funktions import * from typing import Collection, Mapping, Sequence, Tuple, TypeAlias import types_ -from match_tree import LeafNode, IntNode, MatchTree, StructurePath +from match_tree import LeafNode, IntNode, MatchTree, StructurePath, FAIL_NODE, EMPTY_STRUCT_PATH Pattern: TypeAlias = 'NamePattern | IntPattern | SPattern | IgnorePattern' @@ -18,7 +18,7 @@ class NamePattern: return success_leaf def bindings(self) -> Collection[Tuple[str, StructurePath]]: - return ((self.name, tuple()),) + return ((self.name, EMPTY_STRUCT_PATH),) def binds(self, var: str) -> bool: """ @@ -83,15 +83,15 @@ class SPattern: # If the child check is also an int node on pred of our current location: # "raise" that node to be at our current level. if child_location == (*location, 0): - return IntNode(location, dict(((0, LeafNode.fail()), *((1 + v, mt) for v, mt in trees.items()))), fallback) + return IntNode(location, dict(((0, FAIL_NODE), *((1 + v, mt) for v, mt in trees.items()))), fallback) else: - return IntNode(location, {0: LeafNode.fail()}, child) + return IntNode(location, {0: FAIL_NODE}, child) case child: - return IntNode(location, {0: LeafNode.fail()}, child) + return IntNode(location, {0: FAIL_NODE}, child) def bindings(self) -> Collection[Tuple[str, StructurePath]]: # Prepend each binding path of the child pattern with 0 (i.e. subtract one from an int) - return tuple((name, (0, *path)) for (name, path) in self.pred.bindings()) + return tuple((name, StructurePath((0, *path))) for (name, path) in self.pred.bindings()) def binds(self, var: str) -> bool: """