196 lines
6.2 KiB
Python
196 lines
6.2 KiB
Python
from emis_funky_funktions import *
|
|
|
|
from dataclasses import dataclass, field
|
|
from heapq import heappop, heappush
|
|
from operator import eq
|
|
from typing import Callable, Generic, List, Sequence, Set, Tuple, TypeVar
|
|
|
|
# Generic type of a single search state (a node in the search space).
S = TypeVar("S")
|
|
|
|
def _heappush_all(
    heap: List[S],
    new_items: Iterable[S]
) -> List[S]:
    """
    Push every element of `new_items` onto `heap`, one at a time, in iteration order.

    The heap list is modified in place; that same list object is returned so the call
    can be used inside an expression.
    """
    push = heappush
    for entry in new_items:
        push(heap, entry)
    return heap
|
|
|
|
@dataclass(order=True, frozen=True)
class _FrontierNode(Generic[S]):
    """
    One immutable entry in the A* priority queue.

    Ordering, equality and hashing look only at `estimated_final_cost` and then
    `total_cost`. `node` and `path` are excluded from comparisons, so states of any type
    (even unorderable ones) can sit in the heap.
    """

    # Cost so far plus the heuristic estimate of the remaining cost; primary sort key.
    estimated_final_cost: int
    # Accumulated cost of reaching `node`; among equal estimates the cheaper one sorts first.
    total_cost: int
    # The search state this entry stands for.
    node: S = field(compare=False)
    # The states visited before `node` on the way here (`node` itself is not included).
    path: Tuple[S, ...] = field(compare=False)
|
|
|
|
def pathfind(
    neighbors: Callable[[S], Sequence[Tuple[S, int]]],
    heuristic: Callable[[S], int],
    goal: Callable[[S], bool],
    start_state: S
) -> Option[Tuple[Tuple[S, ...], int]]:
    """
    Perform an A* search over an arbitrary search space

    Arguments:
        neighbors: Given a state, this function should return all neighboring states
            along with the costs of moving to that space from the given state.
        heuristic: Given a state, this function should estimate the cost of travelling
            from that state to the goal state. States are never re-expanded once
            visited, so the returned path is only guaranteed to be the cheapest one
            when the estimate is consistent: h(a) <= cost(a -> b) + h(b) for every
            move, and h is 0 at the goal.
        goal: Should return true only for the goal state.
        start_state: The state that pathfinding should start from. States must be
            hashable, because visited states are tracked in a set.

    Returns:
        If no path is available:
            None
        If pathfinding succeeds:
            `Some` of a pair: a tuple of every state along the path, from
            `start_state` up to and including the goal state, and the total cost of
            that path (being at the start state costs nothing).

    Example:
        Navigate from the bottom-left square to the top-right square, where the cost
        of moving onto a square is the number written in it.

        >>> grid = [
        ...     [ 8, 1,   1, 1, 9, 1,   1,   0 ],
        ...     [ 8, 1,   1, 1, 9, 1, 999,   1 ],
        ...     [ 1, 1,   1, 1, 9, 1,   1,   1 ],
        ...     [ 1, 1,   1, 1, 9, 1,   1,   1 ],
        ...     [ 1, 1,  30, 1, 5, 1,   1, 999 ],
        ...     [ 1, 1, 999, 1, 5, 1,   1,   1 ],
        ...     [ 1, 1, 999, 1, 5, 1,   1,   1 ],
        ...     [ 0, 1, 999, 1, 1, 1,   1,   1 ]
        ... ]
        >>> neighbors = lambda l: [
        ...     ((nx, ny), grid[ny][nx])  # Tuple of (x, y) and the cost
        ...     for (nx, ny) in (
        ...         # Enumerate all adjacent squares (even illegal ones)
        ...         (l[0] + dir_x, l[1] + dir_y)
        ...         for (dir_x, dir_y) in [(-1, 0), (1, 0), (0, -1), (0, 1)]
        ...     )
        ...     if nx >= 0 and nx < 8 and ny >= 0 and ny < 8
        ... ]
        >>> heuristic = lambda l: 7 - l[0] + l[1]
        >>> goal = lambda l: l == (7, 0)
        >>> pathfind(neighbors, heuristic, goal, (0, 7))  #doctest: +NORMALIZE_WHITESPACE
        Some((((0, 7), (1, 7), (1, 6), (1, 5), (1, 4), (1, 3),
               (2, 3), (3, 3), (3, 4), (4, 4), (5, 4), (6, 4),
               (6, 3), (6, 2), (7, 2), (7, 1), (7, 0)), 19))
    """
    frontier: List[_FrontierNode[S]] = [_FrontierNode(0, 0, start_state, tuple())]
    visited: Set[S] = set()

    # Test for exhaustion explicitly instead of catching heappop's IndexError, so that
    # any other error raised while popping is not mistaken for "no path".
    while frontier:
        current = heappop(frontier)
        if current.node in visited:
            # This state was already expanded through an earlier entry; this one is a
            # stale duplicate left behind in the heap.
            continue

        new_path = (*current.path, current.node)
        if goal(current.node):
            return Some((new_path, current.total_cost))

        visited.add(current.node)
        _heappush_all(
            frontier,
            [
                _FrontierNode(
                    current.total_cost + cost + heuristic(node),
                    current.total_cost + cost,
                    node,
                    new_path,
                )
                for (node, cost) in neighbors(current.node)
                if node not in visited
            ],
        )

    # Every reachable state was expanded without satisfying `goal`.
    return None
|
|
|
|
@tco_rec
|
|
def pathfind_multi(
|
|
neighbors: Callable[[S], Sequence[Tuple[S, int]]],
|
|
heuristic: Callable[[S, S], int],
|
|
checkpoints: List[S],
|
|
prefix_moves: Tuple[Tuple[S, ...], int] = (tuple(), 0)
|
|
) -> Return[
|
|
Option[Tuple[Tuple[S, ...], int]]
|
|
] | Recur[[
|
|
Callable[[S], Sequence[Tuple[S, int]]],
|
|
Callable[[S, S], int],
|
|
List[S],
|
|
Tuple[Tuple[S, ...], int]
|
|
]]:
|
|
"""
|
|
Pathfind a path between a series of states in sequence
|
|
|
|
For each pair of adjacent nodes in the checkpoints list, a path between those two
|
|
nodes will be found. The returned path passes through each provided node in order.
|
|
|
|
>>> map = [
|
|
... [ 8, 1, 1, 1, 9, 1, 1, 0 ],
|
|
... [ 8, 1, 1, 1, 9, 1, 999, 1 ],
|
|
... [ 1, 1, 1, 1, 9, 1, 1, 1 ],
|
|
... [ 1, 1, 1, 1, 9, 1, 1, 1 ],
|
|
... [ 1, 1, 30, 1, 5, 1, 1, 999 ],
|
|
... [ 1, 1, 999, 1, 5, 1, 1, 1 ],
|
|
... [ 1, 1, 999, 1, 5, 1, 1, 1 ],
|
|
... [ 0, 1, 999, 1, 1, 1, 1, 1 ]
|
|
... ]
|
|
|
|
We re-use the neighbors & heuristic function we introduced in `pathfind()`.
|
|
|
|
>>> neighbors = lambda l: [
|
|
... ((nx, ny), map[ny][nx]) # Tuple of (x, y) and the cost
|
|
... for (nx, ny) in (
|
|
... # Enumerate all adjacent squares (even illegal ones)
|
|
... (l[0] + dir_x, l[1] + dir_y)
|
|
... for (dir_x, dir_y) in [(-1, 0), (1, 0), (0, -1), (0, 1)]
|
|
... )
|
|
... if nx >= 0 and nx < 8 and ny >= 0 and ny < 8
|
|
... ]
|
|
|
|
The heuristic function must provide a heuristic between two points, rather than a
|
|
heuristic based on single point, as in `pathfind()`. If your heuristic function is
|
|
asymmetric, note that the first argument is where we are pathing *to*, and the
|
|
second is where we are pathing *from*.
|
|
|
|
>>> heuristic = lambda f, t: abs(f[0] - t[0]) + abs(f[1] - t[1])
|
|
|
|
Now we pathfind from the bottom left corner, through the top left corner, then finish
|
|
in the bottom right.
|
|
|
|
>>> pathfind_multi(neighbors, heuristic, [(0, 7), (0, 0), (7, 7)]) #doctest: +NORMALIZE_WHITESPACE
|
|
Some((((0, 7), (0, 6), (0, 5), (0, 4), (0, 3),
|
|
(0, 2), (1, 2), (1, 1), (1, 0), (0, 0),
|
|
(1, 0), (2, 0), (3, 0), (3, 1), (3, 2),
|
|
(3, 3), (3, 4), (3, 5), (3, 6), (3, 7),
|
|
(4, 7), (5, 7), (6, 7), (7, 7)), 30))
|
|
"""
|
|
match checkpoints:
|
|
case []:
|
|
return Return(Some(prefix_moves))
|
|
case [single]:
|
|
return Return(Some(((*prefix_moves[0], single), prefix_moves[1])))
|
|
case [start, goal, *next_goals]:
|
|
match pathfind(neighbors, p(heuristic, goal), p(eq, goal), start):
|
|
case None:
|
|
return Return(None)
|
|
case Some((path, cost)):
|
|
return Recur(
|
|
neighbors,
|
|
heuristic,
|
|
[goal, *next_goals],
|
|
(
|
|
(*prefix_moves[0], *path[:-1]),
|
|
prefix_moves[1] + cost
|
|
)
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly executes the examples embedded in the docstrings.
    from doctest import testmod
    testmod()