from dataclasses import dataclass from functools import partial, wraps from operator import not_ from typing import Any, Callable, Concatenate, Generic, Iterator, ParamSpec, Sequence, Tuple, TypeVar A = TypeVar('A') B = TypeVar('B') C = TypeVar('C') D = TypeVar('D') P = ParamSpec('P') P1 = ParamSpec('P1') P2 = ParamSpec('P2') # Compose def c(f2: Callable[[B], C], f1: Callable[P, B]) -> Callable[P, C]: @wraps(f1) def inner(*args: P.args, **kwargs: P.kwargs) -> C: return f2(f1(*args, **kwargs)) return inner # Flip: (A -> B -> C) -> B -> A -> C def flip(f: Callable[P1, Callable[P2, C]]) -> Callable[P2, Callable[P1, C]]: @wraps(f) def inner1(*args2: P2.args, **kwargs2: P2.kwargs) -> Callable[P1, C]: @wraps(f) def inner2(*args1: P1.args, **kwargs1: P1.kwargs) -> C: return f(*args1, **kwargs1)(*args2, **kwargs2) return inner2 return inner1 # Identity function! def ident(x: A) -> A: return x # Partial Appliaction shorthand p = partial # Two and three-argument currying # Defining these pointfree fucks up the types btw def cur2(f: Callable[Concatenate[A, P], C]) -> Callable[[A], Callable[P, C]]: return p(p, f) #type:ignore def cur3(f: Callable[Concatenate[A, B, P], D]) -> Callable[[A], Callable[[B], Callable[P, D]]]: return p(p, p, f) #type:ignore # Curried versions of map & filter with stricter types def p_map(f: Callable[[A], B]) -> Callable[[Sequence[A]], Sequence[B]]: return partial(map, f) #type: ignore def p_filter(f: Callable[[A], bool]) -> Callable[[Sequence[A]], Sequence[A]]: return partial(filter,f) #type: ignore # Normal Accessors @cur2 def indx(i: int, s: Sequence[A]) -> A: return s[i] fst = indx(0) snd = indx(1) # Semantic Editor Combinators class SemEdComb: class Inner(): def __init__(self, f: Callable, name: str): self.f = f self.name = name def and_then(self, other: 'SemEdComb.Inner') -> 'SemEdComb.Inner': return SemEdComb.Inner(c(other.f, self.f), self.name + ' and ' + other.name) def __repr__(self) -> str: return f"SemEdComb*({self.name})" def __call__(self, *args, **kwargs): return self.f(*args, **kwargs) def __init__(self, f: Callable[[Callable],Callable], name: str): self.f = f self.name = name def _c(self, next_f: Callable[[Callable], Callable], next_fname: str) -> 'SemEdComb': return SemEdComb(c(self.f, next_f), self.name + next_fname) RESULT = cur2(c) ARG = flip(RESULT) ALL = p_map @cur3 @staticmethod def INDEX(i, f, arr): arr[i] = f(arr[i]) return arr @cur3 @staticmethod def INDEX_TUP(i: int, f: Callable[[Any], Any], tup: Tuple) -> Tuple: l = list(tup) l[i] = f(l[i]) return (*l,) @cur2 @staticmethod def FIRST(f: Callable[[A], C], tup: Tuple[A, B]) -> Tuple[C, B]: return (f(tup[0]), tup[1]) @cur2 @staticmethod def SECOND(f: Callable[[B], C], tup: Tuple[A, B]) -> Tuple[A, C]: return (tup[0], f(tup[1])) @property def result(self) -> 'SemEdComb': return self._c(SemEdComb.RESULT, '.result') @property def arg(self) -> 'SemEdComb': return self._c(SemEdComb.ARG, '.arg') @property def all(self) -> 'SemEdComb': return self._c(SemEdComb.ALL, '.all') def index(self, i) -> 'SemEdComb': return self._c(SemEdComb.INDEX(i), f'.index({i})') def index_tup(self, i) -> 'SemEdComb': return self._c(SemEdComb.INDEX_TUP(i), f'.index_tup({i})') @property def first(self) -> 'SemEdComb': return self._c(SemEdComb.FIRST, f'.first') @property def second(self) -> 'SemEdComb': return self._c(SemEdComb.SECOND, f'.second') def __repr__(self): return f"SemEdComb({self.name})" def pmap(self, mapper): return SemEdComb.Inner(self.f(mapper), self.name) def map(self, mapper, thing_to_map) -> Callable: return self.pmap(mapper)(thing_to_map) def __call__(self, *args, **kwargs): return self.f(*args, **kwargs) result = SemEdComb(SemEdComb.RESULT, 'result') arg = SemEdComb(SemEdComb.ARG, 'arg') index = lambda i: SemEdComb(SemEdComb.INDEX(i), f'index({i})') index_tup = lambda i: SemEdComb(SemEdComb.INDEX_TUP(i), f'index_tup({i})') first = SemEdComb(SemEdComb.FIRST, 'first') second = SemEdComb(SemEdComb.SECOND, 'second') all_ = SemEdComb(SemEdComb.ALL, 'all') # Tail call optimizing recursion @dataclass class Recur(Generic[P]): def __init__(self, *args: P.args, **kwargs: P.kwargs): self.args = args self.kwargs = kwargs @dataclass(frozen = True) class Return(Generic[B]): val: B def tco_rec(f: Callable[P, Recur[P] | Return[B]]) -> Callable[P, B]: @wraps(f) def tco_loop(*args: P.args, **kwargs: P.kwargs) -> B: while True: match f(*args, **kwargs): case Recur(args=args, kwargs=kwargs): #type:ignore pass case Return(val=val)|val: return val #type:ignore return tco_loop # Options! @dataclass(frozen=True) class Some(Generic[A]): val: A Option = Some[A] | None def map_opt(f: Callable[[A], B], o: Option[A]) -> Option[B]: match o: case Some(val): return Some(f(val)) case none: return none def bind_opt(f: Callable[[A], Option[B]], o: Option[A]) -> Option[B]: match o: case Some(val): return f(val) case none: return none def note(e: B, o: Option[A]) -> Result[A, B]: match o: case Some(val): return Ok(val) case None: return Err(e) # Results! @dataclass(frozen=True) class Ok(Generic[A]): val: A @dataclass(frozen=True) class Err(Generic[B]): err: B def __bool__(self): return False Result = Ok[A] | Err[B] def map_res(f: Callable[[A], C], r: Result[A, B]) -> Result[C, B]: match r: case Ok(val): return Ok(f(val)) case not_okay: return not_okay def bind_res(f: Callable[[A], Result[C, B]], r: Result[A, B]) -> Result[C, B]: match r: case Ok(val): return f(val) case not_okay: return not_okay def map_err(f: Callable[[B], C], r: Result[A, B]) -> Result[A, C]: match r: case Err(e): return Err(f(e)) case oki_doke: return oki_doke def hush(r: Result[A, Any]) -> Option[A]: match r: case Ok(val): return Some(val) case not_okay: return None def try_(handle: Callable[[Exception], B], f: Callable[P, A], *args: P.args, **kwargs: P.kwargs) -> Result[A, B]: try: return Ok(f(*args, **kwargs)) except Exception as e: return Err(handle(e)) def unwrap_r(r: Result[A, Any]) -> A: match r: case Ok(val): return val case Err(e): raise Exception(f'Tried to unwrap an error: {e}') def sequence(s: Sequence[Result[A, B]]) -> Result[Iterator[A], B]: if all(s): return Ok(( unwrap_r(r) for r in s )) else: o = next(filter(not_, s)) assert isinstance(o, Err) return o