from emis_funky_funktions import * from dataclasses import dataclass, field from heapq import heappop, heappush from operator import eq from typing import Callable, Generic, List, Sequence, Set, Tuple, TypeVar S = TypeVar('S') def _heappush_all( heap: List[S], new_items: Iterable[S] ) -> List[S]: "A shorthand for calling `heappush` with several new items" for item in new_items: heappush(heap, item) return heap @dataclass(frozen=True, order=True) class _FrontierNode(Generic[S]): estimated_final_cost: int total_cost: int node: S = field(compare=False) path: Tuple[S, ...] = field(compare=False) def pathfind( neighbors: Callable[[S], Sequence[Tuple[S, int]]], heuristic: Callable[[S], int], goal: Callable[[S], bool], start_state: S ) -> Option[Tuple[List[S], int]]: """ Perform an A* search over an arbitrary search space Arguments: neighbors: Given a state, this function should return all neighboring states along with the costs of moving to that space from the given state. heuristic: Given a state, this function should estimate the cost of travelling from that state to the goal state. goal: Should return true only for the goal state. start_state: The state that pathfinding should start from Returns: If no path is available: None If pathfinding succeeds: A list including the path taken to get to the goal along with the total cost of that path Example: Navigate from the top-left square to the top-right square, where the cost of moving is the number. >>> map = [ ... [ 8, 1, 1, 1, 9, 1, 1, 0 ], ... [ 8, 1, 1, 1, 9, 1, 999, 1 ], ... [ 1, 1, 1, 1, 9, 1, 1, 1 ], ... [ 1, 1, 1, 1, 9, 1, 1, 1 ], ... [ 1, 1, 30, 1, 5, 1, 1, 999 ], ... [ 1, 1, 999, 1, 5, 1, 1, 1 ], ... [ 1, 1, 999, 1, 5, 1, 1, 1 ], ... [ 0, 1, 999, 1, 1, 1, 1, 1 ] ... ] >>> neighbors = lambda l: [ ... ((nx, ny), map[ny][nx]) # Tuple of (x, y) and the cost ... for (nx, ny) in ( ... # Enumerate all adjacent squares (even illegal ones) ... (l[0] + dir_x, l[1] + dir_y) ... for (dir_x, dir_y) in [(-1, 0), (1, 0), (0, -1), (0, 1)] ... ) ... if nx >= 0 and nx < 8 and ny >= 0 and ny < 8 ... ] >>> heuristic = lambda l: 7 - l[0] + l[1] >>> goal = lambda l: l == (7, 0) >>> pathfind(neighbors, heuristic, goal, (0, 7)) #doctest: +NORMALIZE_WHITESPACE Some((((0, 7), (1, 7), (1, 6), (1, 5), (1, 4), (1, 3), (2, 3), (3, 3), (3, 4), (4, 4), (5, 4), (6, 4), (6, 3), (6, 2), (7, 2), (7, 1), (7, 0)), 19)) """ frontier, visited = ([_FrontierNode(0, 0, start_state, tuple())], set()) while True: # Don't look at this in mypy # The types check out but mypy is REALLY bad at unifying types match try_(ident, heappop, frontier): case Err(_): return None case Ok(current) if current.node in visited: pass #RECUR case Ok(current): new_path = (*current.path, current.node) if goal(current.node): return Some((new_path, current.total_cost)) else: visited.add(current.node) frontier, visited = ( #RECUR _heappush_all( frontier, [ _FrontierNode( current.total_cost + cost + heuristic(node), current.total_cost + cost, node, new_path ) for (node, cost) in neighbors(current.node) if node not in visited ] ), visited ) @tco_rec def pathfind_multi( neighbors: Callable[[S], Sequence[Tuple[S, int]]], heuristic: Callable[[S, S], int], checkpoints: List[S], prefix_moves: Tuple[Tuple[S, ...], int] = (tuple(), 0) ) -> Return[ Option[Tuple[Tuple[S, ...], int]] ] | Recur[[ Callable[[S], Sequence[Tuple[S, int]]], Callable[[S, S], int], List[S], Tuple[Tuple[S, ...], int] ]]: """ Pathfind a path between a series of states in sequence For each pair of adjacent nodes in the checkpoints list, a path between those two nodes will be found. The returned path passes through each provided node in order. >>> map = [ ... [ 8, 1, 1, 1, 9, 1, 1, 0 ], ... [ 8, 1, 1, 1, 9, 1, 999, 1 ], ... [ 1, 1, 1, 1, 9, 1, 1, 1 ], ... [ 1, 1, 1, 1, 9, 1, 1, 1 ], ... [ 1, 1, 30, 1, 5, 1, 1, 999 ], ... [ 1, 1, 999, 1, 5, 1, 1, 1 ], ... [ 1, 1, 999, 1, 5, 1, 1, 1 ], ... [ 0, 1, 999, 1, 1, 1, 1, 1 ] ... ] We re-use the neighbors & heuristic function we introduced in `pathfind()`. >>> neighbors = lambda l: [ ... ((nx, ny), map[ny][nx]) # Tuple of (x, y) and the cost ... for (nx, ny) in ( ... # Enumerate all adjacent squares (even illegal ones) ... (l[0] + dir_x, l[1] + dir_y) ... for (dir_x, dir_y) in [(-1, 0), (1, 0), (0, -1), (0, 1)] ... ) ... if nx >= 0 and nx < 8 and ny >= 0 and ny < 8 ... ] The heuristic function must provide a heuristic between two points, rather than a heuristic based on single point, as in `pathfind()`. If your heuristic function is asymmetric, note that the first argument is where we are pathing *to*, and the second is where we are pathing *from*. >>> heuristic = lambda f, t: abs(f[0] - t[0]) + abs(f[1] - t[1]) Now we pathfind from the bottom left corner, through the top left corner, then finish in the bottom right. >>> pathfind_multi(neighbors, heuristic, [(0, 7), (0, 0), (7, 7)]) #doctest: +NORMALIZE_WHITESPACE Some((((0, 7), (0, 6), (0, 5), (0, 4), (0, 3), (0, 2), (1, 2), (1, 1), (1, 0), (0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (4, 7), (5, 7), (6, 7), (7, 7)), 30)) """ match checkpoints: case []: return Return(Some(prefix_moves)) case [single]: return Return(Some(((*prefix_moves[0], single), prefix_moves[1]))) case [start, goal, *next_goals]: match pathfind(neighbors, p(heuristic, goal), p(eq, goal), start): case None: return Return(None) case Some((path, cost)): return Recur( neighbors, heuristic, [goal, *next_goals], ( (*prefix_moves[0], *path[:-1]), prefix_moves[1] + cost ) ) if __name__ == '__main__': import doctest doctest.testmod()