Improve use of ParamSpec

This commit is contained in:
Emi Simpson 2023-02-07 17:01:31 -05:00
parent adddc0704d
commit 1e27ab5293
Signed by: Emi
GPG key ID: A12F2C2FFDC3D847

View file

@ -1,30 +1,41 @@
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial, reduce, wraps from functools import partial, wraps
from typing import Any, Callable, Generic, ParamSpec, Sequence, Tuple, TypeVar from typing import Any, Callable, Concatenate, Generic, ParamSpec, Sequence, Tuple, TypeVar
A = TypeVar('A') A = TypeVar('A')
B = TypeVar('B') B = TypeVar('B')
C = TypeVar('C') C = TypeVar('C')
D = TypeVar('D') D = TypeVar('D')
P = ParamSpec('P') P = ParamSpec('P')
P1 = ParamSpec('P1')
P2 = ParamSpec('P2')
# Compose # Compose
def c(f2: Callable[[B], C], f1: Callable[[A], B]) -> Callable[[A], C]: def c(f2: Callable[[B], C], f1: Callable[P, B]) -> Callable[P, C]:
return lambda a: f2(f1(a)) @wraps(f1)
def inner(*args: P.args, **kwargs: P.kwargs) -> C:
return f2(f1(*args, **kwargs))
return inner
# Flip: (A -> B -> C) -> B -> A -> C # Flip: (A -> B -> C) -> B -> A -> C
def flip(f: Callable[[A],Callable[[B], C]]) -> Callable[[B], Callable[[A], C]]: def flip(f: Callable[P1, Callable[P2, C]]) -> Callable[P2, Callable[P1, C]]:
return wraps(f)(lambda b: wraps(f)(lambda a: f(a)(b))) @wraps(f)
def inner1(*args2: P2.args, **kwargs2: P2.kwargs) -> Callable[P1, C]:
@wraps(f)
def inner2(*args1: P1.args, **kwargs1: P1.kwargs) -> C:
return f(*args1, **kwargs1)(*args2, **kwargs2)
return inner2
return inner1
# Partial Appliaction shorthand # Partial Appliaction shorthand
p = partial p = partial
# Two and three-argument currying # Two and three-argument currying
# Defining these pointfree fucks up the types btw # Defining these pointfree fucks up the types btw
def cur2(f: Callable[[A, B], C]) -> Callable[[A], Callable[[B], C]]: def cur2(f: Callable[Concatenate[A, P], C]) -> Callable[[A], Callable[P, C]]:
return p(p, f) #type:ignore return p(p, f) #type:ignore
def cur3(f: Callable[[A, B, C], D]) -> Callable[[A], Callable[[B], Callable[[C], D]]]: def cur3(f: Callable[Concatenate[A, B, P], D]) -> Callable[[A], Callable[[B], Callable[P, D]]]:
return p(p, p, f) #type:ignore return p(p, p, f) #type:ignore
# Curried versions of map & filter with stricter types # Curried versions of map & filter with stricter types
@ -145,11 +156,13 @@ class Recur(Generic[P]):
class Return(Generic[B]): class Return(Generic[B]):
val: B val: B
@cur2 def tco_rec(f: Callable[P, Recur[P] | Return[B]]) -> Callable[P, B]:
def tco_rec(f: Callable[P, Recur[P] | Return[B] | B], *args: P.args, **kwargs: P.kwargs) -> Callable[P, B]: @wraps(f)
while True: def tco_loop(*args: P.args, **kwargs: P.kwargs) -> B:
match f(*args, **kwargs): while True:
case Recur(args=args, kwargs=kwargs): #type:ignore match f(*args, **kwargs):
pass case Recur(args=args, kwargs=kwargs): #type:ignore
case Return(val=val)|val: pass
return val #type:ignore case Return(val=val)|val:
return val #type:ignore
return tco_loop