Improve use of ParamSpec

This commit is contained in:
Emi Simpson 2023-02-07 17:01:31 -05:00
parent adddc0704d
commit 1e27ab5293
Signed by: Emi
GPG Key ID: A12F2C2FFDC3D847
1 changed files with 29 additions and 16 deletions

View File

@ -1,30 +1,41 @@
from dataclasses import dataclass
from functools import partial, reduce, wraps
from typing import Any, Callable, Generic, ParamSpec, Sequence, Tuple, TypeVar
from functools import partial, wraps
from typing import Any, Callable, Concatenate, Generic, ParamSpec, Sequence, Tuple, TypeVar
A = TypeVar('A')
B = TypeVar('B')
C = TypeVar('C')
D = TypeVar('D')
P = ParamSpec('P')
P1 = ParamSpec('P1')
P2 = ParamSpec('P2')
# Compose
def c(f2: Callable[[B], C], f1: Callable[[A], B]) -> Callable[[A], C]:
return lambda a: f2(f1(a))
def c(f2: Callable[[B], C], f1: Callable[P, B]) -> Callable[P, C]:
@wraps(f1)
def inner(*args: P.args, **kwargs: P.kwargs) -> C:
return f2(f1(*args, **kwargs))
return inner
# Flip: (A -> B -> C) -> B -> A -> C
def flip(f: Callable[[A],Callable[[B], C]]) -> Callable[[B], Callable[[A], C]]:
return wraps(f)(lambda b: wraps(f)(lambda a: f(a)(b)))
def flip(f: Callable[P1, Callable[P2, C]]) -> Callable[P2, Callable[P1, C]]:
@wraps(f)
def inner1(*args2: P2.args, **kwargs2: P2.kwargs) -> Callable[P1, C]:
@wraps(f)
def inner2(*args1: P1.args, **kwargs1: P1.kwargs) -> C:
return f(*args1, **kwargs1)(*args2, **kwargs2)
return inner2
return inner1
# Partial Appliaction shorthand
p = partial
# Two and three-argument currying
# Defining these pointfree fucks up the types btw
def cur2(f: Callable[[A, B], C]) -> Callable[[A], Callable[[B], C]]:
def cur2(f: Callable[Concatenate[A, P], C]) -> Callable[[A], Callable[P, C]]:
return p(p, f) #type:ignore
def cur3(f: Callable[[A, B, C], D]) -> Callable[[A], Callable[[B], Callable[[C], D]]]:
def cur3(f: Callable[Concatenate[A, B, P], D]) -> Callable[[A], Callable[[B], Callable[P, D]]]:
return p(p, p, f) #type:ignore
# Curried versions of map & filter with stricter types
@ -145,11 +156,13 @@ class Recur(Generic[P]):
class Return(Generic[B]):
val: B
@cur2
def tco_rec(f: Callable[P, Recur[P] | Return[B] | B], *args: P.args, **kwargs: P.kwargs) -> Callable[P, B]:
while True:
match f(*args, **kwargs):
case Recur(args=args, kwargs=kwargs): #type:ignore
pass
case Return(val=val)|val:
return val #type:ignore
def tco_rec(f: Callable[P, Recur[P] | Return[B]]) -> Callable[P, B]:
@wraps(f)
def tco_loop(*args: P.args, **kwargs: P.kwargs) -> B:
while True:
match f(*args, **kwargs):
case Recur(args=args, kwargs=kwargs): #type:ignore
pass
case Return(val=val)|val:
return val #type:ignore
return tco_loop