from emis_funky_funktions import *

from dataclasses import dataclass
from functools import wraps
from operator import contains
from typing import Callable, Collection, Mapping, TypeGuard, TypeAlias

def _expected(row: Mapping[B, Collection[Sequence[Any]]]) -> Collection[B]:
    """
    Given a single row from an oracle table, identify the expected terminals
    """
    return [
        terminal
        for terminal, expansions in row.items()
        if len(expansions)
    ]

Action: TypeAlias = Callable[[Sequence[C | D]], Sequence[C | D]]
def parser(
    oracle: Mapping[A, Mapping[B, Collection[Sequence[A | B | Action]]]],
    identify_lexeme: Callable[[D], B],
    start_symbol: A,
) -> Callable[[Sequence[D]], Result[Sequence[C | D], Tuple[D, Collection[B]]]]:
    """
    Produces a parser based on a grammar, an oracle, and a start symbol.

    The `identify_lexeme` argument should be a function which converts a lexeme into the
    token that it represents.  This allows for the actual lexemes that are being fed in to
    be more complex, and store additional data.

    The oracle table my include "action" annotations in its sequences.  Actions should be
    an instance of `Action`, and should work on the AST stack.  Every matched lexeme is
    pushed to the AST stack.  An action may transform this stack by popping some number of
    items off of it, constructing some AST, pushing that AST back to the stack, and then
    returning the modified stack.

    A couple things to note about this process:
        - The stack that is passed to each action is immutable.  "Modifications" should be
          made by simply constructing and returning a new stack.
        - The bottom of the stack is the zero index.

    If a parse is successful, the return value will be the AST stack at the end of the
    parse.  It is up the the caller to verify that this is an expected result.

    If a parse fails, the return value will be a tuple containing the erronious lexeme and
    a collection of expected tokens which failed to match it.

    ### Example:

    We generate a simple grammar:

    >>> class SimpleVariable(IntEnum):
    ...     S = auto()
    ...     Sum = auto()
    ...     Sum_ = auto()
    >>> class SimpleTerminal(IntEnum):
    ...     Number = auto()
    ...     Plus = auto()
    ...     Eof = auto()
    ...     def __repr__(self):
    ...         return self.name
    >>> build_S = lambda x: x[1:]
    >>> build_Sum = lambda x: (x[0](x[1][1]), *x[2:])
    >>> build_Sum_1 = lambda x: (lambda y: x[0] + y, *x[2:])
    >>> build_Sum_2 = lambda x: (lambda y: y, *x)
    >>> grammar = [
    ...     (SimpleVariable.S,    [SimpleVariable.Sum, SimpleTerminal.Eof, build_S]),
    ...     (SimpleVariable.Sum,  [SimpleTerminal.Number, SimpleVariable.Sum_, build_Sum]),
    ...     (SimpleVariable.Sum_, [SimpleTerminal.Plus, SimpleVariable.Sum, build_Sum_1]),
    ...     (SimpleVariable.Sum_, [build_Sum_2]),
    ... ]
    >>> is_term = p_instance(SimpleTerminal)
    >>> is_var = p_instance(SimpleVariable)
    >>> my_oracle_table = oracle_table(is_term, is_var, grammar)
    >>> my_parser = parser(my_oracle_table, lambda x: x[0], SimpleVariable.S)

    >>> my_parser([
    ...     (SimpleTerminal.Number, 1),
    ...     (SimpleTerminal.Plus,),
    ...     (SimpleTerminal.Number, 3),
    ...     (SimpleTerminal.Plus,),
    ...     (SimpleTerminal.Number, 10),
    ...     (SimpleTerminal.Eof,),
    ... ])
    Ok((14,))

    >>> my_parser([
    ...     (SimpleTerminal.Number, 1),
    ...     (SimpleTerminal.Plus,),
    ...     (SimpleTerminal.Number, 3),
    ...     (SimpleTerminal.Number, 10), # <--- this is invalid!
    ...     (SimpleTerminal.Eof,),
    ... ])
    Err(((Number, 10), [Plus, Eof]))
    """
    is_var: Callable[[Any], TypeGuard[A]] = p_instance(start_symbol.__class__)
    is_tok: Callable[[Any], TypeGuard[B]] = p_instance(next(iter(oracle[start_symbol].keys())).__class__)
    def inner(
        stack: Sequence[A | B | Action],
        ast_stack: Sequence[C | D],
        lexemes: Sequence[D],
    ) -> Result[Sequence[C | D], Tuple[D, Collection[B]]]:
        match stack:
            # A [Variable]
            case [top_of_stack, *popped_stack] if is_var(top_of_stack):
                try:
                    expansions = oracle[top_of_stack][identify_lexeme(lexemes[0])]
                except IndexError:
                    raise Exception('Unexpected end of input.  Expected:', _expected(oracle[top_of_stack]))
                match expansions:
                    case []:
                        return Err((lexemes[0], _expected(oracle[top_of_stack])))
                    case [expansion]:
                        return inner((*expansion, *popped_stack), ast_stack, lexemes)
                    case _:
                        raise Exception('Not an LL(1) grammar!!!')
            # B [Token] (match)
            case [top_of_stack, *popped_stack] if is_tok(top_of_stack) and  top_of_stack == identify_lexeme(lexemes[0]):
                return inner(popped_stack, (lexemes[0], *ast_stack), lexemes[1:])
            # B [Token] (no match)
            case [top_of_stack, *popped_stack] if is_tok(top_of_stack):
                assert is_tok(top_of_stack)
                return Err((lexemes[0], (top_of_stack,)))
            # Action
            case [f, *popped_stack]:
                assert hasattr(f, '__call__')
                return inner(popped_stack, f(ast_stack), lexemes)
            # Empty stack (finished parsing)
            case []:
                if len(lexemes):
                    return Err((lexemes[0], []))
                else:
                    return Ok(ast_stack)
        raise Exception('Unreachable!')
    return wraps(parser)(p(inner, [start_symbol], []))

if __name__ == '__main__':
    # Imported here rather than at the top of the file because they are only
    # needed to run the doctests; `doctest.testmod` resolves names such as
    # `oracle_table`, `IntEnum`, and `auto` through this module's globals.
    from enum import auto, IntEnum
    from build_oracle import oracle_table
    import doctest
    doctest.testmod()