#!/usr/bin/env python3
import heapq
import itertools
import numbers
from collections import deque
from enum import Enum


# Random definitions --------------------------------------

class Color:
    """Unicode square glyphs for white and black cells."""
    WHITE = u"\u2B1C"
    BLACK = u"\u2B1B"


# Iterator recipes ----------------------------------------

def last(iterator):
    """Exhaust *iterator* and return its last element.

    Raises:
        IndexError: if the iterator yields nothing.
    """
    # A deque with maxlen=1 keeps only the most recent item while consuming.
    return deque(iterator, maxlen=1)[0]


def exhaust(iterator):
    """Exhaust *iterator*, discarding every element (run it for side effects)."""
    # maxlen=0 consumes the iterator without storing anything.
    deque(iterator, maxlen=0)


# Custom collections --------------------------------------

class defaultlist(list):
    """List that tolerates out-of-bounds (non-negative) integer indices.

    Reading past the end returns a fresh ``val_factory()`` value without
    growing the list; writing past the end first pads the list with
    ``val_factory()`` values. Slices behave exactly like a plain list.
    """

    def __init__(self, lst, val_factory):
        super().__init__(lst)
        self.val_factory = val_factory

    def __setitem__(self, i, x):
        # Slices cannot be compared with ints; hand them to list unchanged.
        if not isinstance(i, slice):
            # range() is empty when i is already in bounds (or negative).
            for _ in range(i - len(self) + 1):
                self.append(self.val_factory())
        super().__setitem__(i, x)

    def __getitem__(self, i):
        if not isinstance(i, slice) and i >= len(self):
            return self.val_factory()
        return super().__getitem__(i)


# Basic linear algebra ------------------------------------

class vector(tuple):
    """Immutable numeric vector built on ``tuple``.

    Convert any iterable to a vector: ``vector((1, 2))`` or ``vector(gen)``.
    ``+`` and ``-`` are element-wise, ``*`` by a number scales, ``*`` by
    another sequence is the dot product, and ``/`` divides every component.
    Length mismatches raise ``ValueError``.
    """

    def __neg__(self):
        return vector(-x for x in self)

    def __add__(self, other):
        o = tuple(other)
        if len(self) != len(o):
            raise ValueError("vectors have different lengths")
        return vector(x + y for x, y in zip(self, o))

    def __sub__(self, other):
        return self + (- vector(other))

    def __mul__(self, other):
        # Accept any scalar (int, float, Fraction, ...), not just int.
        if isinstance(other, numbers.Number):
            return vector(other * x for x in self)
        o = tuple(other)
        if len(o) != len(self):
            raise ValueError("vectors have different lengths")
        return sum(x * y for x, y in zip(self, o))

    def __truediv__(self, d):
        return vector(x / d for x in self)

    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other):
        return other + (-self)


def rotate90pos(v):
    """Rotates 2-D vector v by 90 degrees in positive direction."""
    x, y = v
    return vector((-y, x))


def rotate90neg(v):
    """Rotates 2-D vector v by 90 degrees in negative direction."""
    x, y = v
    return vector((y, -x))


class Cardinals:
    """Unit step vectors on a grid where +y is UP."""
    UP = vector((0, 1))
    DOWN = vector((0, -1))
    LEFT = vector((-1, 0))
    RIGHT = vector((1, 0))
    ORIGIN = vector((0, 0))


# Graphs ---------------------------------------------------

class Graph(dict):
    """Undirected, unweighted graph stored as ``{node: [neighbour, ...]}``.

    Every edge costs 1 for shortest-path purposes. Parallel edges are kept
    as repeated neighbour entries; a self-loop ``add_edge(a, a)`` stores
    ``a`` twice in ``self[a]``.
    """

    nodes = dict.keys

    def add_node(self, a):
        """Add node *a* with no edges; a no-op if it already exists."""
        if a not in self:
            self[a] = []

    def add_edge(self, a, b):
        """Add an undirected edge between *a* and *b*, creating missing nodes."""
        self.add_node(a)
        self.add_node(b)
        self[a].append(b)
        self[b].append(a)

    def nodes_of(self, a):
        """Return the neighbours of node *a*."""
        return self[a]

    def _edges(self, a):
        """Return ``(neighbour, cost)`` pairs for node *a*; every edge costs 1."""
        return [(b, 1) for b in self[a]]

    def rm_loops(self):
        """Remove every self-loop (edge from a node to itself) in place."""
        for a in list(self):
            # Rebuild rather than delete while enumerating (deleting shifts
            # indices and skips entries). Raw entries are paired with node
            # names so subclasses that store extra edge data keep it; slice
            # assignment keeps the same list object.
            self[a][:] = [entry for entry, b in zip(self[a], self.nodes_of(a))
                          if b != a]

    def _settle(self, a):
        """Yield ``(node, distance)`` for every node reachable from *a*.

        Dijkstra's algorithm with a binary heap: nodes are yielded in
        non-decreasing distance order, starting with ``(a, 0)``. Edge costs
        come from ``_edges`` and are assumed non-negative.
        """
        best = {a: 0}
        visited = set()
        # The counter breaks distance ties so nodes themselves are never compared.
        tie = itertools.count()
        heap = [(0, next(tie), a)]
        while heap:
            dist, _, node = heapq.heappop(heap)
            if node in visited:
                continue  # stale heap entry superseded by a shorter path
            visited.add(node)
            yield node, dist
            for nxt, cost in self._edges(node):
                if nxt in visited:
                    continue
                candidate = dist + cost
                if nxt in best and best[nxt] <= candidate:
                    continue
                best[nxt] = candidate
                heapq.heappush(heap, (candidate, next(tie), nxt))

    def min_path(self, a, b):
        """Return the length of the shortest path from *a* to *b*.

        Returns 0 when ``a == b``.

        Raises:
            KeyError: if *a* (or a node reached from it) is not in the graph.
            ValueError: if *b* is not reachable from *a*.
        """
        if a == b:
            return 0
        for node, dist in self._settle(a):
            if node == b:
                return dist
        raise ValueError("no path from {!r} to {!r}".format(a, b))

    def min_paths(self, a):
        """Return ``{node: distance}`` for every node reachable from *a*.

        *a* itself is not included; nodes unreachable from *a* are omitted.

        Raises:
            KeyError: if *a* is not in the graph.
        """
        return {node: dist for node, dist in self._settle(a) if node != a}


class WeightedGraph(Graph):
    """Undirected graph whose adjacency lists hold ``(neighbour, weight)`` pairs.

    Shortest-path queries (``min_path``, ``min_paths``) inherited from
    ``Graph`` use the stored weights as edge costs.
    """

    def nodes_of(self, a):
        """Return the neighbours of node *a* without their weights."""
        return [b for b, _ in self[a]]

    def _edges(self, a):
        """Return the stored ``(neighbour, weight)`` pairs for node *a*."""
        return self[a]

    def add_edge(self, a, b, w):
        """Add an undirected edge of weight *w*, creating missing nodes."""
        self.add_node(a)
        self.add_node(b)
        self[a].append((b, w))
        self[b].append((a, w))