From 3811e55711134ed883d9b716eb35a4b73e04d9b2 Mon Sep 17 00:00:00 2001 From: Emi Simpson Date: Fri, 10 Feb 2023 20:44:23 -0500 Subject: [PATCH] Added an A* routing algorithm --- a_star.py | 117 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 a_star.py diff --git a/a_star.py b/a_star.py new file mode 100644 index 0000000..7a9bc33 --- /dev/null +++ b/a_star.py @@ -0,0 +1,117 @@ +from emis_funky_funktions import * + +from dataclasses import dataclass, field +from heapq import heappop, heappush +from typing import Callable, Generic, List, Sequence, Set, Tuple, TypeVar + +S = TypeVar('S') + +def _heappush_all( + heap: List[S], + new_items: Iterable[S] +) -> List[S]: + "A shorthand for calling `heappush` with several new items" + for item in new_items: + heappush(heap, item) + return heap + +@dataclass(frozen=True, order=True) +class _FrontierNode(Generic[S]): + estimated_final_cost: int + total_cost: int + node: S = field(compare=False) + path: Tuple[S, ...] = field(compare=False) + +def pathfind( + neighbors: Callable[[S], Sequence[Tuple[S, int]]], + heuristic: Callable[[S], int], + goal: Callable[[S], bool], + start_state: S +) -> Option[Tuple[List[S], int]]: + """ + Perform an A* search over an arbitrary search space + + Arguments: + neighbors: Given a state, this function should return all neighboring states + along with the costs of moving to that space from the given state. + heuristic: Given a state, this function should estimate the cost of travelling + from that state to the goal state. + goal: Should return true only for the goal state. + start_state: The state that pathfinding should start from + + Returns: + If no path is available: + None + If pathfinding succeeds: + A list including the path taken to get to the goal along with the total cost + of that path + + Example: + Navigate from the top-left square to the top-right square, where the cost of + moving is the number. + + >>> map = [ + ... [ 8, 1, 1, 1, 9, 1, 1, 0 ], + ... [ 8, 1, 1, 1, 9, 1, 999, 1 ], + ... [ 1, 1, 1, 1, 9, 1, 1, 1 ], + ... [ 1, 1, 1, 1, 9, 1, 1, 1 ], + ... [ 1, 1, 30, 1, 5, 1, 1, 999 ], + ... [ 1, 1, 999, 1, 5, 1, 1, 1 ], + ... [ 1, 1, 999, 1, 5, 1, 1, 1 ], + ... [ 0, 1, 999, 1, 1, 1, 1, 1 ] + ... ] + >>> neighbors = lambda l: [ + ... ((nx, ny), map[ny][nx]) # Tuple of (x, y) and the cost + ... for (nx, ny) in ( + ... # Enumerate all adjacent squares (even illegal ones) + ... (l[0] + dir_x, l[1] + dir_y) + ... for (dir_x, dir_y) in [(-1, 0), (1, 0), (0, -1), (0, 1)] + ... ) + ... if nx >= 0 and nx < 8 and ny >= 0 and ny < 8 + ... ] + >>> heuristic = lambda l: 7 - l[0] + l[1] + >>> goal = lambda l: l == (7, 0) + >>> pathfind(neighbors, heuristic, goal, (0, 7)) #doctest: +NORMALIZE_WHITESPACE + Some((((0, 7), (1, 7), (1, 6), (1, 5), (1, 4), (1, 3), + (2, 3), (3, 3), (3, 4), (4, 4), (5, 4), (6, 4), + (6, 3), (6, 2), (7, 2), (7, 1), (7, 0)), 19)) + """ + @tco_rec + def pathfind_inner( + frontier: List[_FrontierNode[S]], + visited: Set[S] + ) -> Return[Option[Tuple[Tuple[S, ...], int]]] | Recur[[List[_FrontierNode[S]], Set[S]]]: + # Don't look at this in mypy + # The types check out but mypy is REALLY bad at unifying types + match try_(ident, heappop, frontier): + case Err(_): + return Return(None) + case Ok(current) if current.node in visited: + return Recur(frontier, visited) + case Ok(current): + new_path = (*current.path, current.node) + if goal(current.node): + return Return(Some((new_path, current.total_cost))) + else: + visited.add(current.node) + return Recur( + _heappush_all( + frontier, + [ + _FrontierNode( + current.total_cost + cost + heuristic(node), + current.total_cost + cost, + node, + new_path + ) + for (node, cost) in neighbors(current.node) + if node not in visited + ] + ), + visited + ) + return pathfind_inner([_FrontierNode(0, 0, start_state, tuple())], set()) + +if __name__ == '__main__': + import doctest + doctest.testmod() \ No newline at end of file