from emis_funky_funktions import * from dataclasses import dataclass, field from heapq import heappop, heappush from typing import Callable, Generic, List, Sequence, Set, Tuple, TypeVar S = TypeVar('S') def _heappush_all( heap: List[S], new_items: Iterable[S] ) -> List[S]: "A shorthand for calling `heappush` with several new items" for item in new_items: heappush(heap, item) return heap @dataclass(frozen=True, order=True) class _FrontierNode(Generic[S]): estimated_final_cost: int total_cost: int node: S = field(compare=False) path: Tuple[S, ...] = field(compare=False) def pathfind( neighbors: Callable[[S], Sequence[Tuple[S, int]]], heuristic: Callable[[S], int], goal: Callable[[S], bool], start_state: S ) -> Option[Tuple[List[S], int]]: """ Perform an A* search over an arbitrary search space Arguments: neighbors: Given a state, this function should return all neighboring states along with the costs of moving to that space from the given state. heuristic: Given a state, this function should estimate the cost of travelling from that state to the goal state. goal: Should return true only for the goal state. start_state: The state that pathfinding should start from Returns: If no path is available: None If pathfinding succeeds: A list including the path taken to get to the goal along with the total cost of that path Example: Navigate from the top-left square to the top-right square, where the cost of moving is the number. >>> map = [ ... [ 8, 1, 1, 1, 9, 1, 1, 0 ], ... [ 8, 1, 1, 1, 9, 1, 999, 1 ], ... [ 1, 1, 1, 1, 9, 1, 1, 1 ], ... [ 1, 1, 1, 1, 9, 1, 1, 1 ], ... [ 1, 1, 30, 1, 5, 1, 1, 999 ], ... [ 1, 1, 999, 1, 5, 1, 1, 1 ], ... [ 1, 1, 999, 1, 5, 1, 1, 1 ], ... [ 0, 1, 999, 1, 1, 1, 1, 1 ] ... ] >>> neighbors = lambda l: [ ... ((nx, ny), map[ny][nx]) # Tuple of (x, y) and the cost ... for (nx, ny) in ( ... # Enumerate all adjacent squares (even illegal ones) ... (l[0] + dir_x, l[1] + dir_y) ... for (dir_x, dir_y) in [(-1, 0), (1, 0), (0, -1), (0, 1)] ... ) ... if nx >= 0 and nx < 8 and ny >= 0 and ny < 8 ... ] >>> heuristic = lambda l: 7 - l[0] + l[1] >>> goal = lambda l: l == (7, 0) >>> pathfind(neighbors, heuristic, goal, (0, 7)) #doctest: +NORMALIZE_WHITESPACE Some((((0, 7), (1, 7), (1, 6), (1, 5), (1, 4), (1, 3), (2, 3), (3, 3), (3, 4), (4, 4), (5, 4), (6, 4), (6, 3), (6, 2), (7, 2), (7, 1), (7, 0)), 19)) """ @tco_rec def pathfind_inner( frontier: List[_FrontierNode[S]], visited: Set[S] ) -> Return[Option[Tuple[Tuple[S, ...], int]]] | Recur[[List[_FrontierNode[S]], Set[S]]]: # Don't look at this in mypy # The types check out but mypy is REALLY bad at unifying types match try_(ident, heappop, frontier): case Err(_): return Return(None) case Ok(current) if current.node in visited: return Recur(frontier, visited) case Ok(current): new_path = (*current.path, current.node) if goal(current.node): return Return(Some((new_path, current.total_cost))) else: visited.add(current.node) return Recur( _heappush_all( frontier, [ _FrontierNode( current.total_cost + cost + heuristic(node), current.total_cost + cost, node, new_path ) for (node, cost) in neighbors(current.node) if node not in visited ] ), visited ) return pathfind_inner([_FrontierNode(0, 0, start_state, tuple())], set()) if __name__ == '__main__': import doctest doctest.testmod()