117 lines
3.7 KiB
Python
117 lines
3.7 KiB
Python
|
# The star import stays first so the explicit stdlib imports below keep taking
# precedence over any same-named re-exports from emis_funky_funktions.
from emis_funky_funktions import *

from dataclasses import dataclass, field
from heapq import heappop, heappush
from typing import Callable, Generic, Iterable, List, Sequence, Set, Tuple, TypeVar
|
||
|
|
||
|
# Type of a search state; shared by `_heappush_all`, `_FrontierNode` and `pathfind`.
S = TypeVar("S")
|
||
|
|
||
|
def _heappush_all(heap: List[S], new_items: Iterable[S]) -> List[S]:
    """
    Push each element of `new_items` onto `heap`, in iteration order.

    `heap` is mutated in place via `heappush`; the very same list object is
    handed back so the call can be used inline as an expression.
    """
    for pending in new_items:
        heappush(heap, pending)
    return heap
|
||
|
|
||
|
@dataclass(order=True, frozen=True)
class _FrontierNode(Generic[S]):
    """
    One immutable entry in the A* priority queue.

    Instances compare (for ordering and equality) by `estimated_final_cost`
    first and `total_cost` second. `node` and `path` are excluded from every
    comparison, so the search states themselves never need to be orderable.
    """
    # Cost paid so far plus the heuristic's estimate of the remaining cost
    estimated_final_cost: int
    # Cost actually paid to reach `node` from the start state
    total_cost: int
    # The search state this entry refers to
    node: S = field(compare=False)
    # States visited before `node`, oldest first (excludes `node` itself)
    path: Tuple[S, ...] = field(compare=False)
|
||
|
|
||
|
def pathfind(
|
||
|
neighbors: Callable[[S], Sequence[Tuple[S, int]]],
|
||
|
heuristic: Callable[[S], int],
|
||
|
goal: Callable[[S], bool],
|
||
|
start_state: S
|
||
|
) -> Option[Tuple[List[S], int]]:
|
||
|
"""
|
||
|
Perform an A* search over an arbitrary search space
|
||
|
|
||
|
Arguments:
|
||
|
neighbors: Given a state, this function should return all neighboring states
|
||
|
along with the costs of moving to that space from the given state.
|
||
|
heuristic: Given a state, this function should estimate the cost of travelling
|
||
|
from that state to the goal state.
|
||
|
goal: Should return true only for the goal state.
|
||
|
start_state: The state that pathfinding should start from
|
||
|
|
||
|
Returns:
|
||
|
If no path is available:
|
||
|
None
|
||
|
If pathfinding succeeds:
|
||
|
A list including the path taken to get to the goal along with the total cost
|
||
|
of that path
|
||
|
|
||
|
Example:
|
||
|
Navigate from the top-left square to the top-right square, where the cost of
|
||
|
moving is the number.
|
||
|
|
||
|
>>> map = [
|
||
|
... [ 8, 1, 1, 1, 9, 1, 1, 0 ],
|
||
|
... [ 8, 1, 1, 1, 9, 1, 999, 1 ],
|
||
|
... [ 1, 1, 1, 1, 9, 1, 1, 1 ],
|
||
|
... [ 1, 1, 1, 1, 9, 1, 1, 1 ],
|
||
|
... [ 1, 1, 30, 1, 5, 1, 1, 999 ],
|
||
|
... [ 1, 1, 999, 1, 5, 1, 1, 1 ],
|
||
|
... [ 1, 1, 999, 1, 5, 1, 1, 1 ],
|
||
|
... [ 0, 1, 999, 1, 1, 1, 1, 1 ]
|
||
|
... ]
|
||
|
>>> neighbors = lambda l: [
|
||
|
... ((nx, ny), map[ny][nx]) # Tuple of (x, y) and the cost
|
||
|
... for (nx, ny) in (
|
||
|
... # Enumerate all adjacent squares (even illegal ones)
|
||
|
... (l[0] + dir_x, l[1] + dir_y)
|
||
|
... for (dir_x, dir_y) in [(-1, 0), (1, 0), (0, -1), (0, 1)]
|
||
|
... )
|
||
|
... if nx >= 0 and nx < 8 and ny >= 0 and ny < 8
|
||
|
... ]
|
||
|
>>> heuristic = lambda l: 7 - l[0] + l[1]
|
||
|
>>> goal = lambda l: l == (7, 0)
|
||
|
>>> pathfind(neighbors, heuristic, goal, (0, 7)) #doctest: +NORMALIZE_WHITESPACE
|
||
|
Some((((0, 7), (1, 7), (1, 6), (1, 5), (1, 4), (1, 3),
|
||
|
(2, 3), (3, 3), (3, 4), (4, 4), (5, 4), (6, 4),
|
||
|
(6, 3), (6, 2), (7, 2), (7, 1), (7, 0)), 19))
|
||
|
"""
|
||
|
@tco_rec
|
||
|
def pathfind_inner(
|
||
|
frontier: List[_FrontierNode[S]],
|
||
|
visited: Set[S]
|
||
|
) -> Return[Option[Tuple[Tuple[S, ...], int]]] | Recur[[List[_FrontierNode[S]], Set[S]]]:
|
||
|
# Don't look at this in mypy
|
||
|
# The types check out but mypy is REALLY bad at unifying types
|
||
|
match try_(ident, heappop, frontier):
|
||
|
case Err(_):
|
||
|
return Return(None)
|
||
|
case Ok(current) if current.node in visited:
|
||
|
return Recur(frontier, visited)
|
||
|
case Ok(current):
|
||
|
new_path = (*current.path, current.node)
|
||
|
if goal(current.node):
|
||
|
return Return(Some((new_path, current.total_cost)))
|
||
|
else:
|
||
|
visited.add(current.node)
|
||
|
return Recur(
|
||
|
_heappush_all(
|
||
|
frontier,
|
||
|
[
|
||
|
_FrontierNode(
|
||
|
current.total_cost + cost + heuristic(node),
|
||
|
current.total_cost + cost,
|
||
|
node,
|
||
|
new_path
|
||
|
)
|
||
|
for (node, cost) in neighbors(current.node)
|
||
|
if node not in visited
|
||
|
]
|
||
|
),
|
||
|
visited
|
||
|
)
|
||
|
return pathfind_inner([_FrontierNode(0, 0, start_state, tuple())], set())
|
||
|
|
||
|
if __name__ == '__main__':
    # Running this module directly executes the examples in its docstrings.
    from doctest import testmod
    testmod()
|