Improve use of ParamSpec
This commit is contained in:
parent
adddc0704d
commit
1e27ab5293
|
@ -1,30 +1,41 @@
|
|||
from dataclasses import dataclass
|
||||
from functools import partial, reduce, wraps
|
||||
from typing import Any, Callable, Generic, ParamSpec, Sequence, Tuple, TypeVar
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Concatenate, Generic, ParamSpec, Sequence, Tuple, TypeVar
|
||||
|
||||
A = TypeVar('A')
|
||||
B = TypeVar('B')
|
||||
C = TypeVar('C')
|
||||
D = TypeVar('D')
|
||||
P = ParamSpec('P')
|
||||
P1 = ParamSpec('P1')
|
||||
P2 = ParamSpec('P2')
|
||||
|
||||
|
||||
# Compose
|
||||
def c(f2: Callable[[B], C], f1: Callable[[A], B]) -> Callable[[A], C]:
|
||||
return lambda a: f2(f1(a))
|
||||
def c(f2: Callable[[B], C], f1: Callable[P, B]) -> Callable[P, C]:
|
||||
@wraps(f1)
|
||||
def inner(*args: P.args, **kwargs: P.kwargs) -> C:
|
||||
return f2(f1(*args, **kwargs))
|
||||
return inner
|
||||
|
||||
# Flip: (A -> B -> C) -> B -> A -> C
|
||||
def flip(f: Callable[[A],Callable[[B], C]]) -> Callable[[B], Callable[[A], C]]:
|
||||
return wraps(f)(lambda b: wraps(f)(lambda a: f(a)(b)))
|
||||
def flip(f: Callable[P1, Callable[P2, C]]) -> Callable[P2, Callable[P1, C]]:
|
||||
@wraps(f)
|
||||
def inner1(*args2: P2.args, **kwargs2: P2.kwargs) -> Callable[P1, C]:
|
||||
@wraps(f)
|
||||
def inner2(*args1: P1.args, **kwargs1: P1.kwargs) -> C:
|
||||
return f(*args1, **kwargs1)(*args2, **kwargs2)
|
||||
return inner2
|
||||
return inner1
|
||||
|
||||
# Partial Appliaction shorthand
|
||||
p = partial
|
||||
|
||||
# Two and three-argument currying
|
||||
# Defining these pointfree fucks up the types btw
|
||||
def cur2(f: Callable[[A, B], C]) -> Callable[[A], Callable[[B], C]]:
|
||||
def cur2(f: Callable[Concatenate[A, P], C]) -> Callable[[A], Callable[P, C]]:
|
||||
return p(p, f) #type:ignore
|
||||
def cur3(f: Callable[[A, B, C], D]) -> Callable[[A], Callable[[B], Callable[[C], D]]]:
|
||||
def cur3(f: Callable[Concatenate[A, B, P], D]) -> Callable[[A], Callable[[B], Callable[P, D]]]:
|
||||
return p(p, p, f) #type:ignore
|
||||
|
||||
# Curried versions of map & filter with stricter types
|
||||
|
@ -145,11 +156,13 @@ class Recur(Generic[P]):
|
|||
class Return(Generic[B]):
|
||||
val: B
|
||||
|
||||
@cur2
|
||||
def tco_rec(f: Callable[P, Recur[P] | Return[B] | B], *args: P.args, **kwargs: P.kwargs) -> Callable[P, B]:
|
||||
while True:
|
||||
match f(*args, **kwargs):
|
||||
case Recur(args=args, kwargs=kwargs): #type:ignore
|
||||
pass
|
||||
case Return(val=val)|val:
|
||||
return val #type:ignore
|
||||
def tco_rec(f: Callable[P, Recur[P] | Return[B]]) -> Callable[P, B]:
|
||||
@wraps(f)
|
||||
def tco_loop(*args: P.args, **kwargs: P.kwargs) -> B:
|
||||
while True:
|
||||
match f(*args, **kwargs):
|
||||
case Recur(args=args, kwargs=kwargs): #type:ignore
|
||||
pass
|
||||
case Return(val=val)|val:
|
||||
return val #type:ignore
|
||||
return tco_loop
|
Loading…
Reference in a new issue