ai-lab-one/a_star.py

196 lines
6.2 KiB
Python

from emis_funky_funktions import *
from dataclasses import dataclass, field
from heapq import heappop, heappush
from operator import eq
from typing import Callable, Generic, List, Sequence, Set, Tuple, TypeVar
# Generic type of a search state, shared by the helpers and search functions below
S = TypeVar('S')


def _heappush_all(
    heap: List[S],
    new_items: Iterable[S]
) -> List[S]:
    """
    Push each element of `new_items` onto `heap`, in iteration order.

    `heap` is modified in place and must already satisfy the `heapq` heap invariant.
    The same list object is also returned, so the call can be used as an expression.
    """
    for pending in new_items:
        heappush(heap, pending)
    return heap
@dataclass(frozen=True, order=True)
class _FrontierNode(Generic[S]):
    """
    An immutable entry in the A* frontier heap used by `pathfind()`.

    Instances compare by `estimated_final_cost` first, then by `total_cost`.
    `node` and `path` are excluded from comparison, so the states themselves
    never need to be orderable.
    """
    # Cost accumulated so far plus the heuristic's estimate of the remaining cost
    estimated_final_cost: int
    # Cost accumulated along `path` to reach `node`
    total_cost: int
    # The search state this entry represents (not used for ordering)
    node: S = field(compare=False)
    # States visited before `node`, starting with the start state; excludes `node` itself
    path: Tuple[S, ...] = field(compare=False)
def pathfind(
    neighbors: Callable[[S], Sequence[Tuple[S, int]]],
    heuristic: Callable[[S], int],
    goal: Callable[[S], bool],
    start_state: S
) -> Option[Tuple[Tuple[S, ...], int]]:
    """
    Perform an A* search over an arbitrary search space

    Arguments:
        neighbors: Given a state, this function should return all neighboring states
            along with the costs of moving to that space from the given state.
        heuristic: Given a state, this function should estimate the cost of travelling
            from that state to the goal state.
        goal: Should return true only for the goal state.
        start_state: The state that pathfinding should start from

    States must be hashable, because expanded states are tracked in a set.
    Errors raised by `neighbors`, `heuristic`, or `goal` are not caught.

    Returns:
        If no path is available:
            None
        If pathfinding succeeds:
            Some((path, cost)), where `path` is a tuple of every state on the route
            from `start_state` to the goal (both included), and `cost` is the total
            cost of that route

    Example:
        Navigate from the bottom-left square to the top-right square, where the cost
        of moving onto a square is the number in it.

        >>> map = [
        ...     [ 8, 1, 1, 1, 9, 1, 1, 0 ],
        ...     [ 8, 1, 1, 1, 9, 1, 999, 1 ],
        ...     [ 1, 1, 1, 1, 9, 1, 1, 1 ],
        ...     [ 1, 1, 1, 1, 9, 1, 1, 1 ],
        ...     [ 1, 1, 30, 1, 5, 1, 1, 999 ],
        ...     [ 1, 1, 999, 1, 5, 1, 1, 1 ],
        ...     [ 1, 1, 999, 1, 5, 1, 1, 1 ],
        ...     [ 0, 1, 999, 1, 1, 1, 1, 1 ]
        ... ]
        >>> neighbors = lambda l: [
        ...     ((nx, ny), map[ny][nx]) # Tuple of (x, y) and the cost
        ...     for (nx, ny) in (
        ...         # Enumerate all adjacent squares (even illegal ones)
        ...         (l[0] + dir_x, l[1] + dir_y)
        ...         for (dir_x, dir_y) in [(-1, 0), (1, 0), (0, -1), (0, 1)]
        ...     )
        ...     if nx >= 0 and nx < 8 and ny >= 0 and ny < 8
        ... ]
        >>> heuristic = lambda l: 7 - l[0] + l[1]
        >>> goal = lambda l: l == (7, 0)
        >>> pathfind(neighbors, heuristic, goal, (0, 7)) #doctest: +NORMALIZE_WHITESPACE
        Some((((0, 7), (1, 7), (1, 6), (1, 5), (1, 4), (1, 3),
            (2, 3), (3, 3), (3, 4), (4, 4), (5, 4), (6, 4),
            (6, 3), (6, 2), (7, 2), (7, 1), (7, 0)), 19))
    """
    # The heap orders entries by (estimated_final_cost, total_cost) only, so
    # states are never compared against each other.
    frontier: List[_FrontierNode[S]] = [_FrontierNode(0, 0, start_state, tuple())]
    visited: Set[S] = set()
    while frontier:
        current = heappop(frontier)
        if current.node in visited:
            # This state was already expanded from an earlier entry; skip the duplicate.
            continue
        new_path = (*current.path, current.node)
        if goal(current.node):
            return Some((new_path, current.total_cost))
        visited.add(current.node)
        _heappush_all(
            frontier,
            [
                _FrontierNode(
                    # f = g + h: cost so far plus the estimate of what remains
                    current.total_cost + cost + heuristic(node),
                    current.total_cost + cost,
                    node,
                    new_path
                )
                for (node, cost) in neighbors(current.node)
                if node not in visited
            ]
        )
    # The frontier ran dry without reaching a goal state: no path exists.
    return None
@tco_rec
def pathfind_multi(
neighbors: Callable[[S], Sequence[Tuple[S, int]]],
heuristic: Callable[[S, S], int],
checkpoints: List[S],
prefix_moves: Tuple[Tuple[S, ...], int] = (tuple(), 0)
) -> Return[
Option[Tuple[Tuple[S, ...], int]]
] | Recur[[
Callable[[S], Sequence[Tuple[S, int]]],
Callable[[S, S], int],
List[S],
Tuple[Tuple[S, ...], int]
]]:
"""
Pathfind a path between a series of states in sequence
For each pair of adjacent nodes in the checkpoints list, a path between those two
nodes will be found. The returned path passes through each provided node in order.
>>> map = [
... [ 8, 1, 1, 1, 9, 1, 1, 0 ],
... [ 8, 1, 1, 1, 9, 1, 999, 1 ],
... [ 1, 1, 1, 1, 9, 1, 1, 1 ],
... [ 1, 1, 1, 1, 9, 1, 1, 1 ],
... [ 1, 1, 30, 1, 5, 1, 1, 999 ],
... [ 1, 1, 999, 1, 5, 1, 1, 1 ],
... [ 1, 1, 999, 1, 5, 1, 1, 1 ],
... [ 0, 1, 999, 1, 1, 1, 1, 1 ]
... ]
We re-use the neighbors & heuristic function we introduced in `pathfind()`.
>>> neighbors = lambda l: [
... ((nx, ny), map[ny][nx]) # Tuple of (x, y) and the cost
... for (nx, ny) in (
... # Enumerate all adjacent squares (even illegal ones)
... (l[0] + dir_x, l[1] + dir_y)
... for (dir_x, dir_y) in [(-1, 0), (1, 0), (0, -1), (0, 1)]
... )
... if nx >= 0 and nx < 8 and ny >= 0 and ny < 8
... ]
The heuristic function must provide a heuristic between two points, rather than a
heuristic based on single point, as in `pathfind()`. If your heuristic function is
asymmetric, note that the first argument is where we are pathing *to*, and the
second is where we are pathing *from*.
>>> heuristic = lambda f, t: abs(f[0] - t[0]) + abs(f[1] - t[1])
Now we pathfind from the bottom left corner, through the top left corner, then finish
in the bottom right.
>>> pathfind_multi(neighbors, heuristic, [(0, 7), (0, 0), (7, 7)]) #doctest: +NORMALIZE_WHITESPACE
Some((((0, 7), (0, 6), (0, 5), (0, 4), (0, 3),
(0, 2), (1, 2), (1, 1), (1, 0), (0, 0),
(1, 0), (2, 0), (3, 0), (3, 1), (3, 2),
(3, 3), (3, 4), (3, 5), (3, 6), (3, 7),
(4, 7), (5, 7), (6, 7), (7, 7)), 30))
"""
match checkpoints:
case []:
return Return(Some(prefix_moves))
case [single]:
return Return(Some(((*prefix_moves[0], single), prefix_moves[1])))
case [start, goal, *next_goals]:
match pathfind(neighbors, p(heuristic, goal), p(eq, goal), start):
case None:
return Return(None)
case Some((path, cost)):
return Recur(
neighbors,
heuristic,
[goal, *next_goals],
(
(*prefix_moves[0], *path[:-1]),
prefix_moves[1] + cost
)
)
if __name__ == '__main__':
    # Running this file directly checks the examples embedded in the docstrings.
    from doctest import testmod
    testmod()