from dataclasses import dataclass from functools import partial, wraps from operator import not_ from typing import Any, Callable, Concatenate, Generic, FrozenSet, Iterable, Iterator, List, Mapping, ParamSpec, Sequence, Tuple, Type, TypeGuard, TypeVar A = TypeVar('A') B = TypeVar('B') C = TypeVar('C') D = TypeVar('D') P = ParamSpec('P') P1 = ParamSpec('P1') P2 = ParamSpec('P2') # Compose def c(f2: Callable[[B], C], f1: Callable[P, B]) -> Callable[P, C]: """ Compose two functions by passing the output of the second to the input of the first. `c(f1, f2)(*args)` is equivalent to `f1(f2(*args))`. This can also be thought of as mapping the output of a function using the first parameter as a mapper function. >>> double = lambda x: x + x >>> succ = lambda x: x + 1 >>> c(double, succ)(1) 4 >>> c(succ, double)(1) 3 """ @wraps(f1) def inner(*args: P.args, **kwargs: P.kwargs) -> C: return f2(f1(*args, **kwargs)) return inner # Flip: (A -> B -> C) -> B -> A -> C def flip(f: Callable[P1, Callable[P2, C]]) -> Callable[P2, Callable[P1, C]]: """ Reverse the order of the first two arguments of a curried function. This only works with curried functions, so apply `cur2` or `cur3` before applying `flip` if the arguments you want to flip are not curried. >>> pair = lambda x: lambda y: (x, y) >>> pair(1)(2) (1, 2) >>> flip(pair)(1)(2) (2, 1) """ @wraps(f) def inner1(*args2: P2.args, **kwargs2: P2.kwargs) -> Callable[P1, C]: @wraps(f) def inner2(*args1: P1.args, **kwargs1: P1.kwargs) -> C: return f(*args1, **kwargs1)(*args2, **kwargs2) return inner2 return inner1 # Identity function! def ident(x: A) -> A: """ The identity function. Output is identical to input. >>> ident(3) 3 >>> ident(('hello', 8)) ('hello', 8) """ return x def k(replace_with: A) -> Callable[..., A]: """ Get a function which always returns a constant value, regardless of input The argument `replace_with` is the value the the returned function should always return. The returned function can be used as if having any arity, and will always return the same value originally passed to `replace`. >>> always_seven = k(7) >>> always_seven(2) 7 >>> always_seven('hello', 'world!') 7 >>> k('uwu')('NYA!') 'uwu' """ def constant(*args: Any, **kwargs: Any) -> A: "Always return a constant value, typically the one passed to `replace`" return replace_with return constant # Partial Appliaction shorthand p = partial "An alias for partial application" # Two and three-argument currying # Defining these pointfree fucks up the types btw def cur2(f: Callable[Concatenate[A, P], C]) -> Callable[[A], Callable[P, C]]: """ Perform two-argument currying. For example, a function from (A, B) -> C becomes a function A -> B -> C. This can also be though of as simply moving the first argument of a function out front, since it preserves any arguments after the first. That is, a function (A, B, C, kw=D) -> E becomes the function A -> (B, C, kw=D) -> E after being curried using this function. Can also be used as an annotation. >>> @cur2 ... def pair(x, y): ... return (x, y) ... >>> pair(1)(2) (1, 2) >>> alternate_pair = lambda x, y: (x, y) >>> cur2(alternate_pair)(1)(2) (1, 2) >>> threeple = lambda x, y, z: (x, y, z) >>> cur2(threeple)(1)(2, 3) (1, 2, 3) """ return p(p, f) #type:ignore def cur3(f: Callable[Concatenate[A, B, P], D]) -> Callable[[A], Callable[[B], Callable[P, D]]]: """ Perform three-argument currying. See `cur2` for an explaination of how this works. >>> threeple = lambda x, y, z: (x, y, z) >>> cur3(threeple)(1)(2)(3) (1, 2, 3) """ return p(p, p, f) #type:ignore def uncurry2(f: Callable[[A], Callable[P, B]]) -> Callable[Concatenate[A, P], B]: """ Uncurries a two-argument function The inverse of `cur2` >>> uncurry2(lambda x: lambda y: f'{x} {y}')('hello', 'world') 'hello world' """ @wraps(f) def inner(a: A, *args: P.args, **kwargs: P.kwargs) -> B: return f(a)(*args, **kwargs) return inner # Curried versions of map & filter with stricter types def p_map(f: Callable[[A], B]) -> Callable[[Sequence[A]], Sequence[B]]: "A curried version of the built in `map` function" return partial(map, f) #type: ignore def p_filter(f: Callable[[A], bool]) -> Callable[[Sequence[A]], Sequence[A]]: "A curried version of the built in `filter` function" return partial(filter,f) #type: ignore def p_instance(c: Type[A]) -> Callable[[Any], TypeGuard[A]]: "A curried version of the built in `is_instance` function" return flip(cur2(isinstance))(c) #type: ignore class FSet(FrozenSet[A]): "A subclass of FrozenSet with a more succinct and deterministic __repr__" def __repr__(self): """ >>> repr(FSet([1, 2, 3])) '{ 1, 2, 3 }' """ if len(self) == 0: return '{ }' else: return '{ ' + ', '.join(sorted([repr(e) for e in self])) + ' }' def fset(*args: A) -> FSet[A]: "Alias for `frozenset()` which uses varargs and returns `FSet`" return FSet(args) # Normal Accessors @cur2 def indx(i: int, s: Sequence[A]) -> A: """ A curried version of the getitem function >>> get_second = indx(1) >>> get_second(('a', 'b')) 'b' >>> get_second([1, 2, 3, 4]) 2 """ return s[i] fst = indx(0) "Get the first element of a tuple/sequence" snd = indx(1) "Get the second element of a tuple/sequence" def merge_with(conflict: Callable[[B, C], D], m1: Mapping[A, B], m2: Mapping[A, C]) -> Mapping[A, B | C | D]: """ Merge two mappings, handling conflicts with special behaviour. >>> merge_with(lambda a, b: a - b, {'a': 10, "b": 20, "c": 30}, {"b": 2, "c": 3, "d": 4}) == {'a': 10, 'b': 18, 'c': 27, 'd': 4} True """ return { key: ((conflict(m1[key], m2[key]) if key in m2 else m1[key]) if key in m1 else m2[key]) for key in (m1.keys() | m2.keys()) } # Semantic Editor Combinators class SemEdComb: """ A tool which approximates semantic editor combinators in python. Please read https://web.archive.org/web/20221202200001/http://conal.net/blog/posts/semantic-editor-combinators for context. Since Python has no infix function composition, using this pattern can get pretty ugly. This class abuses python's ability to override the property accessor (.) in order to approximate semantic editor combinators. >>> my_func = lambda x: ('abc' + x, 'def') >>> my_func('hi') ('abchi', 'def') >>> altered_func = result.first.map(str.upper, my_func) >>> altered_func('hi') ('ABCHI', 'def') >>> other_altered_func = arg.map(str.upper, my_func) >>> other_altered_func('hello') ('abcHELLO', 'def') Unfortunately, due to limitations of Python's type system, this class is largely untyped. """ class Inner(): "A chain of semantic editor combinators already paired with a map function" def __init__(self, f: Callable, name: str): self.f = f self.name = name def and_then(self, other: 'SemEdComb.Inner') -> 'SemEdComb.Inner': "Composes this with another `SemEdComb.Inner`" return SemEdComb.Inner(c(other.f, self.f), self.name + ' and ' + other.name) def __repr__(self) -> str: return f"SemEdComb*({self.name})" def __call__(self, *args, **kwargs): return self.f(*args, **kwargs) def __init__(self, f: Callable[[Callable],Callable], name: str): self.f = f self.name = name def _c(self, next_f: Callable[[Callable], Callable], next_fname: str) -> 'SemEdComb': return SemEdComb(c(self.f, next_f), self.name + next_fname) RESULT = cur2(c) "Map the result of a function" ARG = flip(RESULT) "Map the argument of a function" ALL = p_map "Map every element of a list" @cur3 @staticmethod def INDEX(i, f, arr): "Map the ith element of a mutable sequence" arr[i] = f(arr[i]) return arr @cur3 @staticmethod def INDEX_TUP(i: int, f: Callable[[Any], Any], tup: Tuple) -> Tuple: "Map the ith element of an immutable sequence" l = list(tup) l[i] = f(l[i]) return (*l,) @cur2 @staticmethod def FIRST(f: Callable[[A], C], tup: Tuple[A, B]) -> Tuple[C, B]: "Map the first element of a two-tuple" return (f(tup[0]), tup[1]) @cur2 @staticmethod def SECOND(f: Callable[[B], C], tup: Tuple[A, B]) -> Tuple[A, C]: "Map the second element of a two-tuple" return (tup[0], f(tup[1])) @property def result(self) -> 'SemEdComb': """ Map the result of a function >>> my_func = lambda s: s + ' backwards is ' + s[::-1] >>> my_func('hello') 'hello backwards is olleh' >>> altered_func = result.map(str.upper, my_func) >>> altered_func('hello') 'HELLO BACKWARDS IS OLLEH' Can be chained in order to work with curried functions as well. That is, the result of a two argument curried function is the result of the result of that function. >>> curried_pair = lambda x: lambda y: (x, y) >>> altered_pair = result.result.second.map(str.upper, curried_pair) >>> altered_pair('hello')('world') ('hello', 'WORLD') """ return self._c(SemEdComb.RESULT, '.result') @property def arg(self) -> 'SemEdComb': """ Map the argument of a function >>> my_func = lambda s: s + ' backwards is ' + s[::-1] >>> my_func('hello') 'hello backwards is olleh' >>> altered_func = arg.map(str.upper, my_func) >>> altered_func('hello') 'HELLO backwards is OLLEH' Can be combined with `.result` to work with curried functions. >>> curried_pair = lambda x: lambda y: (x, y) >>> altered_pair = result.arg.map(str.upper, curried_pair) >>> altered_pair('hello')('world') ('hello', 'WORLD') """ return self._c(SemEdComb.ARG, '.arg') @property def all(self) -> 'SemEdComb': """ Map every element of a sequence To use this as the base of a chain of SECs, write "all_", since "all" by itself refers to the builtin python function, which is different. Note that this returns an iterator, not a sequence, even if the thing being mapped was a sequence or a list. >>> list(all_.map(lambda x: x + x, [1, 2, 3])) [2, 4, 6] >>> my_func = lambda s: [s] * s >>> my_func(3) [3, 3, 3] >>> altered_func = result.all.map(lambda x: x + x, my_func) >>> list(altered_func(3)) [6, 6, 6] """ return self._c(SemEdComb.ALL, '.all') def index(self, i) -> 'SemEdComb': """ Map the ith element of a mutable sequence >>> index(1).map(lambda x: x + x, [1, 2, 3]) [1, 4, 3] >>> my_func = lambda s: [s] * s >>> my_func(3) [3, 3, 3] >>> altered_func = result.index(1).map(lambda x: x + x, my_func) >>> list(altered_func(3)) [3, 6, 3] """ return self._c(SemEdComb.INDEX(i), f'.index({i})') def index_tup(self, i) -> 'SemEdComb': """ Map the ith element of an immutable sequence. >>> index_tup(2).map(lambda x: x + x, (1, 2, 3, 4)) (1, 2, 6, 4) See Also: `index` For a more optimized version of this method specialized to two-tuples, see `first` and `second` """ return self._c(SemEdComb.INDEX_TUP(i), f'.index_tup({i})') @property def first(self) -> 'SemEdComb': """ Map the first element of a two-tuple >>> first.map(lambda x: x+x, (1, 2)) (2, 2) Doesn't work for threeples and fourples. If this is the behaviour you need, try `index_tup` >>> first.map(lambda x: x+x, (1, 2, 3)) (2, 2) """ return self._c(SemEdComb.FIRST, f'.first') @property def second(self) -> 'SemEdComb': """ Map the second element of a two-tuple >>> second.map(lambda x: x+x, (1, 2)) (1, 4) As with `first`, this doesn't work with threeples, fourples, and moreples. >>> second.map(lambda x: x+x, (1, 2, 3)) (1, 4) """ return self._c(SemEdComb.SECOND, f'.second') def __repr__(self): return f"SemEdComb({self.name})" def pmap(self, mapper): """ Set the mapper function, but don't call it yet The name is short for partial map. >>> my_func = lambda s1: lambda s2: f"You entered {s1} and the pair {s2}" >>> my_func(1)(('hello', 'world')) "You entered 1 and the pair ('hello', 'world')" >>> mapper = result.arg.first.pmap(str.upper) >>> altered_func = mapper(my_func) >>> altered_func(1)(('hello', 'world')) "You entered 1 and the pair ('HELLO', 'world')" See also: `map` """ return SemEdComb.Inner(self.f(mapper), self.name) def map(self, mapper, thing_to_map) -> Callable: "Apply the chain of combinators to a mapper and a mappee" return self.pmap(mapper)(thing_to_map) def __call__(self, *args, **kwargs): return self.f(*args, **kwargs) # Pre-constructed base semantic editor combinators result = SemEdComb(SemEdComb.RESULT, 'result') arg = SemEdComb(SemEdComb.ARG, 'arg') index = lambda i: SemEdComb(SemEdComb.INDEX(i), f'index({i})') index_tup = lambda i: SemEdComb(SemEdComb.INDEX_TUP(i), f'index_tup({i})') first = SemEdComb(SemEdComb.FIRST, 'first') second = SemEdComb(SemEdComb.SECOND, 'second') all_ = SemEdComb(SemEdComb.ALL, 'all') # Tail call optimizing recursion @dataclass class Recur(Generic[P]): """ Indicate that the function this is returned from should be called again with new args. Exclusively used with `tco_rec()` """ def __init__(self, *args: P.args, **kwargs: P.kwargs): self.args = args self.kwargs = kwargs @dataclass(frozen = True) class Return(Generic[B]): """ Indicate that the function this is returned from should return this value Exclusively used with `tco_rec()` """ val: B def tco_rec(f: Callable[P, Recur[P] | Return[B]]) -> Callable[P, B]: """ Run a tail-recursive function in a mannor which will not overflow the stack. Wraps a function in a loop which transforms its return type. The function is expected to return an instance of `Recur` rather than calling itself to recur. The arguments passed to the returned `Recur` instance become the arguments to the next iteration of the function call. When the function is ready to return for real, it should return an instance of `Return`. The function will be transformed by `tco_rec` to look as if it is a normal function. >>> @tco_rec ... def factorial(n, coefficient = 1): ... if n > 1: ... return Recur(n - 1, coefficient * n) ... else: ... return Return(coefficient) >>> factorial(4) 24 """ @wraps(f) def tco_loop(*args: P.args, **kwargs: P.kwargs) -> B: while True: match f(*args, **kwargs): case Recur(args=args, kwargs=kwargs): #type:ignore pass case Return(val=val)|val: return val #type:ignore return tco_loop # Options! @dataclass(frozen=True) class Some(Generic[A]): """ The positive part of an optional datatype Component of `Option` and counterpart of `None` """ val: A def __repr__(self) -> str: return f'Some({self.val!r})' Option = Some[A] | None "An Option datatype, aka Maybe" def map_opt(f: Callable[[A], B], o: Option[A]) -> Option[B]: """ Map the contents of an optional data type. Has no effect on `None` >>> map_opt(str.upper, Some('hello')) Some('HELLO') >>> map_opt(str.upper, None) is None True """ match o: case Some(val): return Some(f(val)) case none: return none def bind_opt(f: Callable[[A], Option[B]], o: Option[A]) -> Option[B]: """ wow! monads! (aka 'and_then') >>> halve = lambda n: Some(n//2) if n % 2 == 0 else None >>> [halve(2), halve(3)] [Some(1), None] >>> bind_opt(halve, Some(4)) Some(2) >>> bind_opt(halve, Some(5)) is None True >>> bind_opt(halve, None) is None True """ match o: case Some(val): return f(val) case none: return none def note(e: Callable[[], B], o: Option[A]) -> 'Result[A, B]': """ Convert an `Option` to a `Result` by attaching an error to the `None` variants `e` should be a zero-argument function which produces the desired error value. It will be called if and only if `o` is `None`. >>> note(lambda: 'woops!', Some(1)) Ok(1) >>> note(lambda: 'woops!', None) Err('woops!') """ match o: case Some(val): return Ok(val) case None: return Err(e()) def unwrap_opt(r: Option[A]) -> A: """ Assert that an `Option` is `Some` and return it's value. Throws: `AssertionError` - The result was NOT okay. The `AssertionError` will have two arguments: The first is a string to make it more obvious what happened. The second is the error that was stored in the `Err`. >>> unwrap_opt(Some('hai!')) 'hai!' >>> unwrap_opt(None) #doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): AssertionError: ('Tried to unwrap a None value') """ match r: case Some(val): return val case None: raise AssertionError('Tried to unwrap a None value') def drop_none(l: Iterable[Option[A]]) -> Sequence[A]: """ Drop every instance of `None` from a list, unwraping the rest >>> drop_none([Some(1), None, Some(2), None, Some(3)]) [1, 2, 3] """ return [o.val for o in l if o is not None] # Results! @dataclass(frozen=True) class Ok(Generic[A, B]): """ The positive part of a result (either) datatype Component of `Result` and counterpart of `Err` """ val: A def __repr__(self) -> str: return f'Ok({self.val!r})' def __lshift__(self, other: 'Callable[[A], Result[C, B]]') -> 'Result[C, B]': "Alias for bind" return other(self.val) def __le__(self, other: 'Callable[[A], C]') -> 'Result[C, B]': "Alias for map" return Ok(other(self.val)) @dataclass(frozen=True) class Err(Generic[A, B]): """ The error part of a result (either) datatype Component of `Result` and counterpart of `Ok` """ err: B def __repr__(self) -> str: return f'Err({self.err!r})' def __bool__(self): return False def __lshift__(self, other: 'Callable[[A], Result[C, B]]') -> 'Result[C, B]': "Alias for bind" """ Alias for bind >>> my_result = Err('oh noes!') >>> my_result <<(lambda x: Ok(x + 1)) Err('oh noes!') """ return self #type:ignore def __le__(self, other: 'Callable[[A], C]') -> 'Result[C, B]': "Alias for map" return self #type:ignore Result = Ok[A, B] | Err[A, B] "A Result datatype, aka Either" def map_res(f: Callable[[A], C], r: Result[A, B]) -> Result[C, B]: """ Map the success value of a result >>> map_res(str.upper, Ok('hai!')) Ok('HAI!') >>> map_res(str.upper, Err('oh noes')) Err('oh noes') """ match r: case Ok(val): return Ok(f(val)) case not_okay: return not_okay #type:ignore def bind_res(f: Callable[[A], Result[C, B]], r: Result[A, B]) -> Result[C, B]: """ Perform an fallible operation for successful results. >>> halve = lambda n: Ok(n//2) if n % 2 == 0 else Err(f'{n} is not divisible by 2') >>> [halve(2), halve(3)] [Ok(1), Err('3 is not divisible by 2')] >>> bind_res(halve, Ok(4)) Ok(2) >>> bind_res(halve, Ok(5)) Err('5 is not divisible by 2') >>> bind_res(halve, Err('not okay in the 1st place')) Err('not okay in the 1st place') """ match r: case Ok(val): return f(val) case not_okay: return not_okay #type:ignore def map_err(f: Callable[[B], C], r: Result[A, B]) -> Result[A, C]: """ Map the error value of a result >>> map_err(str.upper, Ok('hai!')) Ok('hai!') >>> map_err(str.upper, Err('oh noes')) Err('OH NOES') """ match r: case Err(e): return Err(f(e)) case oki_doke: return oki_doke #type:ignore def hush(r: Result[A, Any]) -> Option[A]: """ Convert a `Result` to an `Option` by converting any errors to `None` >>> hush(Ok('hai!')) Some('hai!') >>> hush(Err('oh noes')) is None True """ match r: case Ok(val): return Some(val) case not_okay: return None def try_(handle: Callable[[Exception], B], f: Callable[P, A], *args: P.args, **kwargs: P.kwargs) -> Result[A, B]: """ Try-catch in a function! Attempt to perform and operation, and `Err` on failure Arguments: handle - A function which handles any exceptions which arise. The return type is what will be wrapped into the resulting `Err`. This is not called if nothing goes wrong. f - The fallible function to try. If this succeeds without raising an error, that value is returned in an `Ok`. If this raises an exception, that exception will be passed to `handle`. args - Will be passed to `f` when it is called. kwargs - Will be passed to `f` when it is called. >>> try_(ident, int, '3') Ok(3) >>> try_(ident, int, 'three') Err(ValueError("invalid literal for int() with base 10: 'three'")) """ try: return Ok(f(*args, **kwargs)) except Exception as e: return Err(handle(e)) def try_converge( handle: Callable[[Exception], A], f: Callable[P, A], *args: P.args, **kwargs: P.kwargs ) -> A: """ Try-catch in a function! Attempt to perform and operation, and handle failure Arguments: handle - A function which handles any exceptions which arise. The return type of the handle should mirror the original return type of the function. f - The fallible function to try. If this succeeds without raising an error, that value is returned. If this raises an exception, that exception will be passed to `handle`. args - Will be passed to `f` when it is called. kwargs - Will be passed to `f` when it is called. >>> try_converge(k(-1), int, '3') 3 >>> try_converge(k(-1), int, 'three') -1 """ try: return f(*args, **kwargs) except Exception as e: return handle(e) def unwrap_r(r: Result[A, Any]) -> A: """ Assert that a `Result` is `Ok` and return it's value. Throws: `AssertionError` - The result was NOT okay. The `AssertionError` will have two arguments: The first is a string to make it more obvious what happened. The second is the error that was stored in the `Err`. >>> unwrap_r(Ok('hai!')) 'hai!' >>> unwrap_r(Err('oh noes')) is None #doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): AssertionError: ('Tried to unwrap an error: ', 'oh noes') """ match r: case Ok(val): return val case Err(e): raise AssertionError(f'Tried to unwrap an error: ', e) def sequence(s: Sequence[Result[A, B]]) -> Result[Sequence[A], B]: """ Convert a list of results into a result of a list. If the input sequence contains only `Ok` results, then the output is similarly `Ok`, and contains a list of all the unwrapped values of the `Ok`s. If there are any errors, proccessing of the sequence is immediately stopped, and the first error encountered is returned. >>> sequence([Ok(1), Ok(2), Ok(3)]) Ok([1, 2, 3]) >>> sequence([Ok(1), Err('Oops!'), Err('Aw man!')]) Err('Oops!') """ if all(s): return Ok(tuple(map(unwrap_r, s))) else: o = next(filter(not_, s)) assert isinstance(o, Err) return o #type:ignore def partition(s: Sequence[Result[A, B]]) -> tuple[Sequence[A], Sequence[B]]: """ Turn a list of results into a list of Ok values and a list of Err values >>> partition([Ok(1), Ok(2), Err('Aaaaa!'), Ok(4), Err('OH GOD HELP')]) ((1, 2, 4), ('Aaaaa!', 'OH GOD HELP')) """ return ( tuple( succ.val for succ in s if isinstance(succ, Ok) ), tuple( e.err for e in s if isinstance(e, Err) ), ) def trace(x: A) -> A: """ Print a value in passing Equivalent to the identity function **except** for the fact that it prints the value to the screen before returning. The value is printed with the prefix "TRACE:" to make it easy to see what printed. >>> trace(1 + 2) * 4 TRACE: 3 12 """ print(f'TRACE:', x) return x def profile(f: Callable[P, A]) -> Callable[P, A]: """ Wraps a function and check how long it takes to execute Returns a function which is identical to the input, but when called, attempts to record how long it takes to execute the function, and prints that information to the screen. >>> from time import sleep >>> profile(ident)(1) #doctest: +ELLIPSIS TIME OF ident(): ...ms 1 """ from time import perf_counter @wraps(f) def profiled(*args: P.args, **kwargs: P.kwargs) -> A: start_time = perf_counter() o = f(*args, **kwargs) stop_time = perf_counter() print(f'TIME OF {f.__name__}(): {1000 * (stop_time - start_time):.2f}ms') return o return profiled if __name__ == '__main__': import doctest doctest.testmod()