ai-lab-one/emis_funky_funktions.py

253 lines
6.4 KiB
Python
Raw Normal View History

2023-02-07 16:59:55 +00:00
from dataclasses import dataclass
2023-02-07 22:01:31 +00:00
from functools import partial, wraps
2023-02-10 01:46:55 +00:00
from operator import not_
from typing import Any, Callable, Concatenate, Generic, Iterator, ParamSpec, Sequence, Tuple, TypeVar
2023-02-07 16:59:55 +00:00
A = TypeVar('A')
B = TypeVar('B')
C = TypeVar('C')
D = TypeVar('D')
P = ParamSpec('P')
2023-02-07 22:01:31 +00:00
P1 = ParamSpec('P1')
P2 = ParamSpec('P2')
2023-02-07 16:59:55 +00:00
# Compose
2023-02-07 22:01:31 +00:00
def c(f2: Callable[[B], C], f1: Callable[P, B]) -> Callable[P, C]:
    """Compose two callables right-to-left.

    ``c(f2, f1)(*args, **kwargs)`` is ``f2(f1(*args, **kwargs))``.  The
    returned function takes whatever arguments ``f1`` takes and is decorated
    with ``functools.wraps(f1)``, so it carries ``f1``'s metadata.
    """
    @wraps(f1)
    def inner(*args: P.args, **kwargs: P.kwargs) -> C:
        # Run the inner function first, then feed its result to the outer one.
        intermediate = f1(*args, **kwargs)
        return f2(intermediate)
    return inner
2023-02-07 16:59:55 +00:00
# Flip: (A -> B -> C) -> B -> A -> C
2023-02-07 22:01:31 +00:00
def flip(f: Callable[P1, Callable[P2, C]]) -> Callable[P2, Callable[P1, C]]:
    """Swap the order in which a two-stage (curried) callable takes its arguments.

    If ``f(*a1, **k1)(*a2, **k2)`` yields a value, then
    ``flip(f)(*a2, **k2)(*a1, **k1)`` yields the same value.  Both wrapper
    layers are decorated with ``functools.wraps(f)``.
    """
    @wraps(f)
    def inner1(*args2: P2.args, **kwargs2: P2.kwargs) -> Callable[P1, C]:
        @wraps(f)
        def inner2(*args1: P1.args, **kwargs1: P1.kwargs) -> C:
            # The arguments received last are the ones `f` wants first.
            first_stage = f(*args1, **kwargs1)
            return first_stage(*args2, **kwargs2)
        return inner2
    return inner1
2023-02-07 16:59:55 +00:00
2023-02-10 01:46:55 +00:00
# Identity function!
def ident(x: A) -> A:
    """Return the argument unchanged (the very same object)."""
    unchanged = x
    return unchanged
2023-02-07 16:59:55 +00:00
# Partial application shorthand: `p` is simply an alias for functools.partial.
p = partial
# Two and three-argument currying
# Defining these point-free breaks the inferred types, by the way
2023-02-07 22:01:31 +00:00
def cur2(f: Callable[Concatenate[A, P], C]) -> Callable[[A], Callable[P, C]]:
2023-02-07 16:59:55 +00:00
return p(p, f) #type:ignore
2023-02-07 22:01:31 +00:00
def cur3(f: Callable[Concatenate[A, B, P], D]) -> Callable[[A], Callable[[B], Callable[P, D]]]:
2023-02-07 16:59:55 +00:00
return p(p, p, f) #type:ignore
# Curried versions of map & filter with stricter types
def p_map(f: Callable[[A], B]) -> Callable[[Sequence[A]], Sequence[B]]:
    """Curried ``map``: fix the mapping function now, supply the data later.

    NOTE: the annotation promises a ``Sequence``, but at runtime the returned
    callable hands back the lazy iterator produced by the builtin ``map``.
    """
    mapper = partial(map, f)  # type: ignore
    return mapper
def p_filter(f: Callable[[A], bool]) -> Callable[[Sequence[A]], Sequence[A]]:
    """Curried ``filter``: fix the predicate now, supply the data later.

    NOTE: the annotation promises a ``Sequence``, but at runtime the returned
    callable hands back the lazy iterator produced by the builtin ``filter``.
    """
    keeper = partial(filter, f)  # type: ignore
    return keeper
# Normal Accessors
@cur2
def indx(i: int, s: Sequence[A]) -> A:
    """Return ``s[i]``.

    Curried through ``cur2``, so ``indx(i)`` is an accessor awaiting the
    sequence and ``indx(i)(s)`` performs the lookup.
    """
    element = s[i]
    return element
# Accessors for the elements at index 0 and index 1, built from the curried `indx`.
fst = indx(0)
snd = indx(1)
# Semantic Editor Combinators
class SemEdComb:
    """A semantic editor combinator: a named function from callables to callables.

    An instance wraps ``f`` (which turns a "mapper" into an editing function)
    plus a human-readable ``name``.  Combinators are chained through the
    lower-case properties/methods (``.result``, ``.arg``, ``.all``,
    ``.index(i)``, ``.index_tup(i)``, ``.first``, ``.second``); each one nests
    a further edit inside the current one and extends the name.
    """

    class Inner():
        """A combinator that has already been given its mapper, ready to apply."""

        def __init__(self, f: Callable, name: str):
            self.name = name
            self.f = f

        def and_then(self, other: 'SemEdComb.Inner') -> 'SemEdComb.Inner':
            """Run this edit first, then ``other`` on its output."""
            chained = c(other.f, self.f)
            combined_name = self.name + ' and ' + other.name
            return SemEdComb.Inner(chained, combined_name)

        def __repr__(self) -> str:
            return f"SemEdComb*({self.name})"

        def __call__(self, *args, **kwargs):
            # Delegate straight to the wrapped editing function.
            return self.f(*args, **kwargs)

    def __init__(self, f: Callable[[Callable],Callable], name: str):
        self.name = name
        self.f = f

    def _c(self, next_f: Callable[[Callable], Callable], next_fname: str) -> 'SemEdComb':
        """Build a new combinator that applies ``next_f`` first, then ``self.f``."""
        composed = c(self.f, next_f)
        return SemEdComb(composed, self.name + next_fname)

    # Primitive editors.  Each takes the mapper first and the value to edit last.
    RESULT = cur2(c)      # RESULT(mapper)(fn): apply mapper to fn's return value
    ARG = flip(RESULT)    # ARG(mapper)(fn): apply mapper before the value reaches fn
    ALL = p_map           # ALL(mapper)(seq): map the mapper over seq (lazy `map`)

    @cur3
    @staticmethod
    def INDEX(i, f, arr):
        """Replace ``arr[i]`` with ``f(arr[i])`` IN PLACE and return ``arr``."""
        edited = f(arr[i])
        arr[i] = edited
        return arr

    @cur3
    @staticmethod
    def INDEX_TUP(i: int, f: Callable[[Any], Any], tup: Tuple) -> Tuple:
        """Return a new tuple equal to ``tup`` except position ``i`` is ``f(tup[i])``."""
        items = list(tup)
        items[i] = f(items[i])
        return tuple(items)

    @cur2
    @staticmethod
    def FIRST(f: Callable[[A], C], tup: Tuple[A, B]) -> Tuple[C, B]:
        """Return the pair ``(f(tup[0]), tup[1])``."""
        edited_head = f(tup[0])
        return (edited_head, tup[1])

    @cur2
    @staticmethod
    def SECOND(f: Callable[[B], C], tup: Tuple[A, B]) -> Tuple[A, C]:
        """Return the pair ``(tup[0], f(tup[1]))``."""
        untouched_head = tup[0]
        return (untouched_head, f(tup[1]))

    @property
    def result(self) -> 'SemEdComb':
        return self._c(SemEdComb.RESULT, '.result')

    @property
    def arg(self) -> 'SemEdComb':
        return self._c(SemEdComb.ARG, '.arg')

    @property
    def all(self) -> 'SemEdComb':
        return self._c(SemEdComb.ALL, '.all')

    def index(self, i) -> 'SemEdComb':
        suffix = f'.index({i})'
        return self._c(SemEdComb.INDEX(i), suffix)

    def index_tup(self, i) -> 'SemEdComb':
        suffix = f'.index_tup({i})'
        return self._c(SemEdComb.INDEX_TUP(i), suffix)

    @property
    def first(self) -> 'SemEdComb':
        return self._c(SemEdComb.FIRST, '.first')

    @property
    def second(self) -> 'SemEdComb':
        return self._c(SemEdComb.SECOND, '.second')

    def __repr__(self):
        return f"SemEdComb({self.name})"

    def pmap(self, mapper):
        """Bind ``mapper`` to this combinator, yielding a callable ``Inner``."""
        bound = self.f(mapper)
        return SemEdComb.Inner(bound, self.name)

    def map(self, mapper, thing_to_map) -> Callable:
        """Apply ``mapper`` at the location this combinator targets in ``thing_to_map``.

        NOTE(review): annotated as returning ``Callable``, but the value is
        whatever the edited function returns for ``thing_to_map``.
        """
        editor = self.pmap(mapper)
        return editor(thing_to_map)

    def __call__(self, *args, **kwargs):
        # Calling the combinator directly is the same as calling its wrapped f.
        return self.f(*args, **kwargs)
# Ready-made root combinators, one per primitive editor on SemEdComb.
# `index` and `index_tup` take the position to edit and return a combinator.
result = SemEdComb(SemEdComb.RESULT, 'result')
arg = SemEdComb(SemEdComb.ARG, 'arg')
index = lambda i: SemEdComb(SemEdComb.INDEX(i), f'index({i})')
index_tup = lambda i: SemEdComb(SemEdComb.INDEX_TUP(i), f'index_tup({i})')
first = SemEdComb(SemEdComb.FIRST, 'first')
second = SemEdComb(SemEdComb.SECOND, 'second')
2023-02-07 22:01:46 +00:00
# Root combinator for SemEdComb.ALL; presumably spelled `all_` so the builtin
# `all` (used elsewhere in this module) is not shadowed.
all_ = SemEdComb(SemEdComb.ALL, 'all')
2023-02-07 16:59:55 +00:00
# Tail call optimizing recursion
@dataclass
class Recur(Generic[P]):
def __init__(self, *args: P.args, **kwargs: P.kwargs):
self.args = args
self.kwargs = kwargs
@dataclass(frozen = True)
class Return(Generic[B]):
val: B
2023-02-07 22:01:31 +00:00
def tco_rec(f: Callable[P, Recur[P] | Return[B]]) -> Callable[P, B]:
@wraps(f)
def tco_loop(*args: P.args, **kwargs: P.kwargs) -> B:
while True:
match f(*args, **kwargs):
case Recur(args=args, kwargs=kwargs): #type:ignore
pass
case Return(val=val)|val:
return val #type:ignore
2023-02-09 02:32:24 +00:00
return tco_loop
# Options!
@dataclass(frozen=True)
class Some(Generic[A]):
val: A
# An optional value: either Some(val) or None.
Option = Some[A] | None
def map_opt(f: Callable[[A], B], o: Option[A]) -> Option[B]:
    """Apply ``f`` inside a ``Some``; anything else is returned untouched."""
    if isinstance(o, Some):
        return Some(f(o.val))
    # Not a Some (normally None): pass it through as-is.
    return o
def bind_opt(f: Callable[[A], Option[B]], o: Option[A]) -> Option[B]:
    """Feed the value of a ``Some`` to ``f`` (which itself returns an Option).

    Anything that is not a ``Some`` is returned untouched.
    """
    if isinstance(o, Some):
        return f(o.val)
    # Not a Some (normally None): pass it through as-is.
    return o
2023-02-09 02:44:43 +00:00
def note(e: B, o: Option[A]) -> Result[A, B]:
match o:
case Some(val):
return Ok(val)
case None:
return Err(e)
2023-02-09 02:32:24 +00:00
# Results!
@dataclass(frozen=True)
class Ok(Generic[A]):
    """The success case of a Result: an immutable wrapper around ``val``."""
    # The successful value.
    val: A
@dataclass(frozen=True)
class Err(Generic[B]):
    """The failure case of a Result: an immutable wrapper around ``err``.

    An ``Err`` is always falsy, so truthiness checks can separate failures
    from successes.
    """
    # The error payload.
    err: B

    def __bool__(self) -> bool:
        # Failures never count as "true", regardless of the payload.
        return False
2023-02-09 02:32:24 +00:00
# A computation outcome: either Ok(val) on success or Err(err) on failure.
Result = Ok[A] | Err[B]
def map_res(f: Callable[[A], C], r: Result[A, B]) -> Result[C, B]:
    """Apply ``f`` to the value inside an ``Ok``; anything else is returned untouched."""
    if isinstance(r, Ok):
        return Ok(f(r.val))
    # Not an Ok (normally an Err): pass it through as-is.
    return r
def bind_res(f: Callable[[A], Result[C, B]], r: Result[A, B]) -> Result[C, B]:
    """Feed the value of an ``Ok`` to ``f`` (which itself returns a Result).

    Anything that is not an ``Ok`` is returned untouched.
    """
    if isinstance(r, Ok):
        return f(r.val)
    # Not an Ok (normally an Err): pass it through as-is.
    return r
def map_err(f: Callable[[B], C], r: Result[A, B]) -> Result[A, C]:
    """Apply ``f`` to the payload inside an ``Err``; anything else is returned untouched."""
    if isinstance(r, Err):
        return Err(f(r.err))
    # Not an Err (normally an Ok): pass it through as-is.
    return r
def hush(r: Result[A, Any]) -> Option[A]:
    """Convert a Result into an Option, discarding any error.

    ``Ok(val)`` becomes ``Some(val)``; everything else becomes ``None``.
    """
    return Some(r.val) if isinstance(r, Ok) else None
def try_(handle: Callable[[Exception], B], f: Callable[P, A], *args: P.args, **kwargs: P.kwargs) -> Result[A, B]:
    """Call ``f(*args, **kwargs)`` and capture the outcome as a Result.

    A normal return is wrapped in ``Ok``.  If ``f`` raises an ``Exception``,
    it is passed to ``handle`` and the handler's return value is wrapped in
    ``Err``.
    """
    try:
        outcome = f(*args, **kwargs)
    except Exception as exc:
        return Err(handle(exc))
    return Ok(outcome)
def unwrap_r(r: Result[A, Any]) -> A:
    """Extract the value from an ``Ok``.

    Raises:
        Exception: if ``r`` is an ``Err``; the message includes the error payload.
    """
    if isinstance(r, Ok):
        return r.val
    if isinstance(r, Err):
        e = r.err
        raise Exception(f'Tried to unwrap an error: {e}')
    # Anything that is neither Ok nor Err falls through and yields None.
def sequence(s: Sequence[Result[A, B]]) -> Result[Iterator[A], B]:
    """Collapse a sequence of Results into a single Result.

    If every element is truthy (``Err`` is always falsy), returns an ``Ok``
    holding a lazy generator of the unwrapped values.  Otherwise returns the
    first falsy element, which is asserted to be an ``Err``.
    """
    if not all(s):
        first_failure = next(filter(not_, s))
        assert isinstance(first_failure, Err)
        return first_failure
    return Ok(unwrap_r(item) for item in s)