diff --git a/lib.py b/lib.py index 4e5956d..102846a 100644 --- a/lib.py +++ b/lib.py @@ -2,11 +2,18 @@ import itertools from collections import deque +from enum import Enum # Iterator recipes ---------------------------------------- def last(iterator): + """Exhausts an iterator and returns its last element.""" return deque(iterator, maxlen=1)[0] +def exhaust(iterator): + """Exhausts an iterator.""" + deque(iterator, maxlen=0) + +# Custom collections -------------------------------------- class defaultlist(list): """Default list class. Allows writing and reading out of bounds.""" def __init__(self, lst, val_factory): @@ -23,17 +30,56 @@ class defaultlist(list): return self.val_factory() return super().__getitem__(i) +# Basic linear algebra ------------------------------------ +class vector(tuple): + """Vector class. Convert iterator to vector.""" + def __neg__(self): + return vector(-x for x in self) -def memoize(f): - cache = dict() - def memf(*args): - key = tuple(args) - if key not in cache: - cache[key] = f(*args) - return cache[key] - return memf + def __add__(self, other): + o = tuple(other) + if len(self) != len(o): + raise ValueError("vectors have different lengths") + return vector(x + y for x, y in zip(self, o)) + def __sub__(self, other): + return self + (- vector(other)) + def __mul__(self, other): + if type(other) == int: + return vector(other * x for x in self) + o = tuple(other) + if len(o) != len(self): + raise ValueError("vectors have different lengths") + return sum(x * y for x, y in zip(self, o)) + + def __truediv__(self, d): + return vector(x/d for x in self) + + __radd__ = __add__ + __rmul__ = __mul__ + + def __rsub__(self, other): + return other + (-self) + +def rotate90pos(v): + """Rotates vector v by 90 degrees in positive direction.""" + x, y = v + return vector((-y, x)) + +def rotate90neg(v): + """Rotates vector v by 90 degrees in negative direction.""" + x, y = v + return vector((x, -y)) + +class Cardinals(Enum): + UP = vector((0, 1)) + DOWN = vector((0, -1)) + LEFT = vector((-1, 0)) + RIGHT = vector((1, 0)) + ORIGIN = vector((0, 0)) + +# Graphs --------------------------------------------------- class Graph(dict): nodes = dict.keys