Clean up some type errors

And in doing so, fix some real errors and remove some unused code
This commit is contained in:
Emi Simpson 2024-03-15 19:20:19 -04:00
parent 7b59e21183
commit f04553ac4d
Signed by: Emi
GPG Key ID: A12F2C2FFDC3D847
4 changed files with 27 additions and 34 deletions

View File

@ -829,7 +829,7 @@ def sequence(s: Sequence[Result[A, B]]) -> Result[Sequence[A], B]:
Err('Oops!') Err('Oops!')
""" """
if all(s): if all(s):
return Ok(list(map(unwrap_r, s))) return Ok(tuple(map(unwrap_r, s)))
else: else:
o = next(filter(not_, s)) o = next(filter(not_, s))
assert isinstance(o, Err) assert isinstance(o, Err)

27
ir.py
View File

@ -2,7 +2,8 @@ from emis_funky_funktions import *
from typing import Collection, Mapping, Sequence, Tuple, TypeAlias from typing import Collection, Mapping, Sequence, Tuple, TypeAlias
from functools import reduce from functools import reduce
from match_tree import MatchException, StructurePath, LeafNode, merge_all_trees, IntNode from match_tree import MatchTree, MatchException, StructurePath, LeafNode, merge_all_trees, IntNode, EMPTY_STRUCT_PATH, FAIL_NODE
from patterns import Pattern
import types_ import types_
@ -91,7 +92,7 @@ BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = (
@dataclass(frozen=True) @dataclass(frozen=True)
class Function: class Function:
forms: 'Sequence[Tuple[patterns.Pattern, Expression]]' forms: 'Sequence[Tuple[Pattern, Expression]]'
def subst(self, expression: Expression, variable: str) -> Expression: def subst(self, expression: Expression, variable: str) -> Expression:
return Function([ return Function([
@ -105,12 +106,8 @@ class Function:
def step(self) -> Option[Expression]: def step(self) -> Option[Expression]:
return None return None
def function_to_tree(self) -> 'Result[Expression, MatchException]':
subtrees = (patt.match_tree([], LeafNode.from_value(expr)) for (patt, expr) in self.forms)
return reduce(merge_trees, subtrees, LeafNode(tuple()))
def eliminate(self, v: Expression) -> Result[Expression, MatchException]: def eliminate(self, v: Expression) -> Result[Expression, MatchException]:
match_trees = tuple(pattern.match_tree([], LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body))) for (pattern, body) in self.forms) match_trees = tuple(pattern.match_tree(EMPTY_STRUCT_PATH, LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body))) for (pattern, body) in self.forms)
unified_match_tree = merge_all_trees(match_trees) unified_match_tree = merge_all_trees(match_trees)
return compile_tree(unified_match_tree, v) return compile_tree(unified_match_tree, v)
@ -265,6 +262,7 @@ class Switch:
def subst(self, expression: Expression, variable: str) -> Expression: def subst(self, expression: Expression, variable: str) -> Expression:
return Switch( return Switch(
{i: e.subst(expression, variable) for i, e in self.branches.items()}, {i: e.subst(expression, variable) for i, e in self.branches.items()},
self.fallback,
self.switching_on.subst(expression, variable)) self.switching_on.subst(expression, variable))
def is_value(self) -> bool: def is_value(self) -> bool:
@ -273,7 +271,7 @@ class Switch:
def step(self) -> Option[Expression]: def step(self) -> Option[Expression]:
match self.switching_on.step(): match self.switching_on.step():
case Some(switch_expr_stepped): case Some(switch_expr_stepped):
return Switch(self.branches, switch_expr_stepped) return Some(Switch(self.branches, self.fallback, switch_expr_stepped))
case None: case None:
match self.switching_on: match self.switching_on:
case Int(n): case Int(n):
@ -295,15 +293,6 @@ class Switch:
) + f':{self.fallback.codegen()}' ) + f':{self.fallback.codegen()}'
def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Result[Expression, MatchException]: def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Result[Expression, MatchException]:
def match_int_builtin(lookup: Callable[[int], Expression]) -> Callable[[Expression], Option[Expression]]:
def match_inner(i: Expression) -> Expression:
match i:
case Int(value):
return lookup(value)
raise Exception('Bad type! Eep!')
return match_inner
match tree: match tree:
case LeafNode([match]): case LeafNode([match]):
return Ok(match) return Ok(match)
@ -322,6 +311,7 @@ def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Re
return Err(e) return Err(e)
case Ok(fallback): case Ok(fallback):
return Ok(Switch(dict(zip(specific_trees.keys(), exprs)), fallback, match_against)) return Ok(Switch(dict(zip(specific_trees.keys(), exprs)), fallback, match_against))
raise Exception('Unreachable')
def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]: def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]:
def access_location(part: int) -> Callable[[Expression], Expression]: def access_location(part: int) -> Callable[[Expression], Expression]:
@ -340,12 +330,13 @@ def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression
return c(location_to_ir(StructurePath(rest_location)), access_location(part)) return c(location_to_ir(StructurePath(rest_location)), access_location(part))
raise Exception('Unreachable') raise Exception('Unreachable')
def bindings_to_lets(bindings: Sequence[Tuple[str, StructurePath]], deconstructing_term: Expression, body_expr: Expression) -> Expression: def bindings_to_lets(bindings: Collection[Tuple[str, StructurePath]], deconstructing_term: Expression, body_expr: Expression) -> Expression:
match bindings: match bindings:
case []: case []:
return body_expr return body_expr
case [(binding_name, location), *rest]: case [(binding_name, location), *rest]:
return LetBinding(binding_name, location_to_ir(location)(deconstructing_term), bindings_to_lets(rest, deconstructing_term, body_expr)) return LetBinding(binding_name, location_to_ir(location)(deconstructing_term), bindings_to_lets(rest, deconstructing_term, body_expr))
raise Exception('Unreachable')
def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> Expression: def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> Expression:
match bindings: match bindings:

View File

@ -13,6 +13,8 @@ PatternID = NewType('PatternID', 'int')
MatchTree: TypeAlias = 'LeafNode[A] | IntNode[A]' MatchTree: TypeAlias = 'LeafNode[A] | IntNode[A]'
EMPTY_STRUCT_PATH: StructurePath = StructurePath(tuple())
class MatchException(IntEnum): class MatchException(IntEnum):
Incomplete = auto() Incomplete = auto()
Ambiguous = auto() Ambiguous = auto()
@ -25,19 +27,17 @@ class LeafNode(Generic[A]):
def from_value(value: A) -> 'LeafNode[A]': def from_value(value: A) -> 'LeafNode[A]':
return LeafNode((value,)) return LeafNode((value,))
@staticmethod
def fail() -> 'LeafNode[Any]':
return LeafNode([])
def is_complete(self) -> bool: def is_complete(self) -> bool:
return len(matches) > 0 return len(self.matches) > 0
def __repr__(self) -> str: def __repr__(self) -> str:
return '{..}' if len(self.matches) else 'X' return '{..}' if len(self.matches) else 'X'
def fill_empty_leaves(self, with_values: Sequence[A]) -> MatchTree: def fill_empty_leaves(self, with_values: Collection[A]) -> MatchTree:
return self if len(self.matches) else LeafNode(with_values) return self if len(self.matches) else LeafNode(with_values)
FAIL_NODE: LeafNode[Any] = LeafNode(tuple())
@dataclass(frozen=True) @dataclass(frozen=True)
class IntNode(Generic[A]): class IntNode(Generic[A]):
location: StructurePath location: StructurePath
@ -51,7 +51,7 @@ class IntNode(Generic[A]):
merge_trees(self.fallback_tree, other) merge_trees(self.fallback_tree, other)
) )
def fill_empty_leaves(self, with_values: Sequence[A]) -> MatchTree: def fill_empty_leaves(self, with_values: Collection[A]) -> MatchTree:
return IntNode( return IntNode(
self.location, self.location,
{i: tree.fill_empty_leaves(with_values) for (i, tree) in self.specific_trees.items()}, {i: tree.fill_empty_leaves(with_values) for (i, tree) in self.specific_trees.items()},
@ -70,7 +70,8 @@ def merge_trees(t1: 'MatchTree[A]', t2: 'MatchTree[A]') -> 'MatchTree[A]':
match (t1, t2): match (t1, t2):
case (LeafNode(matches1), LeafNode(matches2)): case (LeafNode(matches1), LeafNode(matches2)):
return LeafNode((*matches1, *matches2)) return LeafNode((*matches1, *matches2))
case (LeafNode(matches), other_node) | (other_node, LeafNode(matches)): case (LeafNode(matches), other_node) | (other_node, LeafNode(matches)): #type: ignore
# For ignore, see: https://github.com/python/mypy/issues/13950
return other_node.fill_empty_leaves(matches) return other_node.fill_empty_leaves(matches)
case (IntNode(location1, specific_trees1, fallback_tree1) as tree1, IntNode(location2, specific_trees2, fallback_tree2) as tree2): case (IntNode(location1, specific_trees1, fallback_tree1) as tree1, IntNode(location2, specific_trees2, fallback_tree2) as tree2):
if location1 == location2: if location1 == location2:
@ -98,4 +99,5 @@ def merge_trees(t1: 'MatchTree[A]', t2: 'MatchTree[A]') -> 'MatchTree[A]':
raise Exception('Unreachable') raise Exception('Unreachable')
def merge_all_trees(trees: 'Iterable[MatchTree[A]]') -> 'MatchTree[A]': def merge_all_trees(trees: 'Iterable[MatchTree[A]]') -> 'MatchTree[A]':
return reduce(merge_trees, trees, LeafNode.fail()) fail_node: MatchTree[A] = FAIL_NODE # Annotation or type checking purposes
return reduce(merge_trees, trees, fail_node)

View File

@ -3,7 +3,7 @@ from emis_funky_funktions import *
from typing import Collection, Mapping, Sequence, Tuple, TypeAlias from typing import Collection, Mapping, Sequence, Tuple, TypeAlias
import types_ import types_
from match_tree import LeafNode, IntNode, MatchTree, StructurePath from match_tree import LeafNode, IntNode, MatchTree, StructurePath, FAIL_NODE, EMPTY_STRUCT_PATH
Pattern: TypeAlias = 'NamePattern | IntPattern | SPattern | IgnorePattern' Pattern: TypeAlias = 'NamePattern | IntPattern | SPattern | IgnorePattern'
@ -18,7 +18,7 @@ class NamePattern:
return success_leaf return success_leaf
def bindings(self) -> Collection[Tuple[str, StructurePath]]: def bindings(self) -> Collection[Tuple[str, StructurePath]]:
return ((self.name, tuple()),) return ((self.name, EMPTY_STRUCT_PATH),)
def binds(self, var: str) -> bool: def binds(self, var: str) -> bool:
""" """
@ -83,15 +83,15 @@ class SPattern:
# If the child check is also an int node on pred of our current location: # If the child check is also an int node on pred of our current location:
# "raise" that node to be at our current level. # "raise" that node to be at our current level.
if child_location == (*location, 0): if child_location == (*location, 0):
return IntNode(location, dict(((0, LeafNode.fail()), *((1 + v, mt) for v, mt in trees.items()))), fallback) return IntNode(location, dict(((0, FAIL_NODE), *((1 + v, mt) for v, mt in trees.items()))), fallback)
else: else:
return IntNode(location, {0: LeafNode.fail()}, child) return IntNode(location, {0: FAIL_NODE}, child)
case child: case child:
return IntNode(location, {0: LeafNode.fail()}, child) return IntNode(location, {0: FAIL_NODE}, child)
def bindings(self) -> Collection[Tuple[str, StructurePath]]: def bindings(self) -> Collection[Tuple[str, StructurePath]]:
# Prepend each binding path of the child pattern with 0 (i.e. subtract one from an int) # Prepend each binding path of the child pattern with 0 (i.e. subtract one from an int)
return tuple((name, (0, *path)) for (name, path) in self.pred.bindings()) return tuple((name, StructurePath((0, *path))) for (name, path) in self.pred.bindings())
def binds(self, var: str) -> bool: def binds(self, var: str) -> bool:
""" """