# ai-lab-one/a_star.py
#
# A* pathfinding over an arbitrary search space.
# Last modified: 2023-02-11 01:44:23 +00:00
from emis_funky_funktions import *
from dataclasses import dataclass, field
from heapq import heappop, heappush
from typing import Callable, Generic, List, Sequence, Set, Tuple, TypeVar
# Type variable shared by the generic definitions in this module (heap
# elements in `_heappush_all`, search states in `_FrontierNode` / `pathfind`).
S = TypeVar("S")
def _heappush_all(
heap: List[S],
new_items: Iterable[S]
) -> List[S]:
"A shorthand for calling `heappush` with several new items"
for item in new_items:
heappush(heap, item)
return heap
@dataclass(order=True, frozen=True)
class _FrontierNode(Generic[S]):
    """
    One entry in the A* frontier, which is kept as a heap via ``heapq``.

    Instances compare only on ``(estimated_final_cost, total_cost)``. The
    state and its path are excluded from comparison, so the states themselves
    never need to be orderable.
    """
    # Cost accumulated so far plus the heuristic estimate of the remaining
    # cost; the primary sort key for the heap.
    estimated_final_cost: int
    # Cost accumulated along ``path`` to reach ``node``; breaks ties on the
    # estimate (lower cost so far sorts first).
    total_cost: int
    # The search state this entry refers to.
    node: S = field(compare=False)
    # States visited before ``node``, in order (``node`` itself is excluded).
    path: Tuple[S, ...] = field(compare=False)
def pathfind(
neighbors: Callable[[S], Sequence[Tuple[S, int]]],
heuristic: Callable[[S], int],
goal: Callable[[S], bool],
start_state: S
) -> Option[Tuple[List[S], int]]:
"""
Perform an A* search over an arbitrary search space
Arguments:
neighbors: Given a state, this function should return all neighboring states
along with the costs of moving to that space from the given state.
heuristic: Given a state, this function should estimate the cost of travelling
from that state to the goal state.
goal: Should return true only for the goal state.
start_state: The state that pathfinding should start from
Returns:
If no path is available:
None
If pathfinding succeeds:
A list including the path taken to get to the goal along with the total cost
of that path
Example:
Navigate from the top-left square to the top-right square, where the cost of
moving is the number.
>>> map = [
... [ 8, 1, 1, 1, 9, 1, 1, 0 ],
... [ 8, 1, 1, 1, 9, 1, 999, 1 ],
... [ 1, 1, 1, 1, 9, 1, 1, 1 ],
... [ 1, 1, 1, 1, 9, 1, 1, 1 ],
... [ 1, 1, 30, 1, 5, 1, 1, 999 ],
... [ 1, 1, 999, 1, 5, 1, 1, 1 ],
... [ 1, 1, 999, 1, 5, 1, 1, 1 ],
... [ 0, 1, 999, 1, 1, 1, 1, 1 ]
... ]
>>> neighbors = lambda l: [
... ((nx, ny), map[ny][nx]) # Tuple of (x, y) and the cost
... for (nx, ny) in (
... # Enumerate all adjacent squares (even illegal ones)
... (l[0] + dir_x, l[1] + dir_y)
... for (dir_x, dir_y) in [(-1, 0), (1, 0), (0, -1), (0, 1)]
... )
... if nx >= 0 and nx < 8 and ny >= 0 and ny < 8
... ]
>>> heuristic = lambda l: 7 - l[0] + l[1]
>>> goal = lambda l: l == (7, 0)
>>> pathfind(neighbors, heuristic, goal, (0, 7)) #doctest: +NORMALIZE_WHITESPACE
Some((((0, 7), (1, 7), (1, 6), (1, 5), (1, 4), (1, 3),
(2, 3), (3, 3), (3, 4), (4, 4), (5, 4), (6, 4),
(6, 3), (6, 2), (7, 2), (7, 1), (7, 0)), 19))
"""
@tco_rec
def pathfind_inner(
frontier: List[_FrontierNode[S]],
visited: Set[S]
) -> Return[Option[Tuple[Tuple[S, ...], int]]] | Recur[[List[_FrontierNode[S]], Set[S]]]:
# Don't look at this in mypy
# The types check out but mypy is REALLY bad at unifying types
match try_(ident, heappop, frontier):
case Err(_):
return Return(None)
case Ok(current) if current.node in visited:
return Recur(frontier, visited)
case Ok(current):
new_path = (*current.path, current.node)
if goal(current.node):
return Return(Some((new_path, current.total_cost)))
else:
visited.add(current.node)
return Recur(
_heappush_all(
frontier,
[
_FrontierNode(
current.total_cost + cost + heuristic(node),
current.total_cost + cost,
node,
new_path
)
for (node, cost) in neighbors(current.node)
if node not in visited
]
),
visited
)
return pathfind_inner([_FrontierNode(0, 0, start_state, tuple())], set())
if __name__ == '__main__':
    # Run the examples embedded in this module's docstrings (e.g. the one on
    # `pathfind`). testmod() only prints failures, so turn its failure count
    # into the process exit status to make a failing example visible to
    # callers such as CI.
    import doctest
    raise SystemExit(1 if doctest.testmod().failed else 0)