2023-03-13 23:36:25 +01:00
|
|
|
#!/usr/bin/env python3
|
2023-03-13 23:31:20 +01:00
|
|
|
|
2023-03-15 18:20:23 +01:00
|
|
|
import itertools
|
|
|
|
from collections import deque
|
2023-03-20 23:46:10 +01:00
|
|
|
from enum import Enum
|
2023-03-15 18:20:23 +01:00
|
|
|
|
|
|
|
# Iterator recipes ----------------------------------------
|
|
|
|
def last(iterator):
    """Consume ``iterator`` completely and return the final item it yields.

    Raises:
        IndexError: if ``iterator`` yields nothing at all.
    """
    # A deque capped at one slot drops every item except the newest one.
    tail = deque(iterator, maxlen=1)
    return tail[0]
|
|
|
|
|
2023-03-20 23:46:10 +01:00
|
|
|
def exhaust(iterator):
    """Run ``iterator`` to completion, throwing away every item it yields."""
    for _ in iterator:
        pass
|
|
|
|
|
|
|
|
# Custom collections --------------------------------------
|
2023-03-15 18:20:23 +01:00
|
|
|
class defaultlist(list):
    """List that tolerates integer indices past its end.

    Reading at or beyond ``len(self)`` returns a fresh ``val_factory()``
    value without storing it; writing there first pads the list with
    ``val_factory()`` values up to the target index.  Negative indices and
    slices behave exactly as they do on a plain ``list``.
    """

    def __init__(self, lst, val_factory):
        """Copy the items of ``lst`` and remember the default factory.

        Args:
            lst: iterable providing the initial items.
            val_factory: zero-argument callable producing a default value.
        """
        super().__init__(lst)
        self.val_factory = val_factory

    def __setitem__(self, i, x):
        """Store ``x`` at ``i``, growing the list with defaults if needed."""
        if isinstance(i, slice):
            # Slice assignment has no "out of bounds" notion; defer to list.
            super().__setitem__(i, x)
            return
        # range() of a non-positive count is empty, so in-range and negative
        # indices add no padding.
        missing = i - len(self) + 1
        self.extend(self.val_factory() for _ in range(missing))
        super().__setitem__(i, x)

    def __getitem__(self, i):
        """Return the item at ``i``, or a fresh default past the end.

        Slices are delegated to ``list`` and therefore yield a plain list.
        """
        if isinstance(i, slice):
            return super().__getitem__(i)
        if i >= len(self):
            return self.val_factory()
        return super().__getitem__(i)
|
|
|
|
|
2023-03-20 23:46:10 +01:00
|
|
|
# Basic linear algebra ------------------------------------
|
|
|
|
class vector(tuple):
    """Immutable vector built from any iterable of components.

    Supports negation, element-wise addition and subtraction, scaling by an
    ``int`` or ``float``, the dot product (``v * w``) and division by a
    scalar.  The other operand of ``+``, ``-`` and the dot product may be
    any iterable of the same length.
    """

    def __neg__(self):
        """Return the vector with every component negated."""
        return vector(-x for x in self)

    def __add__(self, other):
        """Return the element-wise sum of ``self`` and ``other``.

        Raises:
            ValueError: if the two operands differ in length.
        """
        o = tuple(other)
        if len(self) != len(o):
            raise ValueError("vectors have different lengths")
        return vector(x + y for x, y in zip(self, o))

    def __sub__(self, other):
        """Return the element-wise difference ``self - other``."""
        return self + (-vector(other))

    def __mul__(self, other):
        """Scale by a number, or take the dot product with an iterable.

        Returns:
            A ``vector`` when ``other`` is an ``int`` or ``float``;
            otherwise the sum of the pairwise products.

        Raises:
            ValueError: if ``other`` is an iterable of a different length.
        """
        # isinstance (rather than an exact type match) lets floats scale the
        # vector too, in line with __truediv__ accepting any divisor.
        if isinstance(other, (int, float)):
            return vector(other * x for x in self)
        o = tuple(other)
        if len(o) != len(self):
            raise ValueError("vectors have different lengths")
        return sum(x * y for x, y in zip(self, o))

    def __truediv__(self, d):
        """Return the vector with every component divided by ``d``."""
        return vector(x / d for x in self)

    # Addition, scaling and the dot product are symmetric in their operands,
    # so the reflected forms reuse the same implementations.
    __radd__ = __add__
    __rmul__ = __mul__

    def __rsub__(self, other):
        """Return ``other - self`` for an iterable ``other``."""
        return vector(other) + (-self)
|
|
|
|
|
|
|
|
def rotate90pos(v):
    """Turn the two-component vector ``v`` by 90 degrees, positive sense.

    Maps ``(x, y)`` to ``(-y, x)`` and returns the result as a ``vector``.
    """
    first, second = v
    turned = (-second, first)
    return vector(turned)
|
2023-03-15 18:20:23 +01:00
|
|
|
|
2023-03-20 23:46:10 +01:00
|
|
|
def rotate90neg(v):
    """Rotate the two-component vector ``v`` by 90 degrees, negative sense.

    Maps ``(x, y)`` to ``(y, -x)``, which is the inverse of the positive
    quarter turn ``(x, y) -> (-y, x)``.  The result is returned as a
    ``vector``.
    """
    x, y = v
    # (x, -y) would mirror the vector across the x-axis instead of turning it.
    return vector((y, -x))
|
2023-03-14 12:44:57 +01:00
|
|
|
|
2023-03-20 23:46:10 +01:00
|
|
|
class Cardinals(Enum):
    """Named two-component vectors: four axis-aligned unit steps and zero."""

    # Steps along the second component.
    UP = vector((0, 1))
    DOWN = vector((0, -1))

    # Steps along the first component.
    LEFT = vector((-1, 0))
    RIGHT = vector((1, 0))

    # The zero vector.
    ORIGIN = vector((0, 0))
|
2023-03-14 12:44:57 +01:00
|
|
|
|
2023-03-20 23:46:10 +01:00
|
|
|
# Graphs ---------------------------------------------------
|
2023-03-13 23:31:20 +01:00
|
|
|
class Graph(dict):
    """Undirected, unweighted graph stored as an adjacency mapping.

    Every key is a node and every value is the list of that node's
    neighbours.  An edge is recorded in the lists of both of its endpoints.
    """

    # ``graph.nodes()`` is simply the dict's key view.
    nodes = dict.keys

    def add_node(self, a):
        """Register node ``a`` with no neighbours unless already present."""
        if a not in self:
            self[a] = []

    def add_edge(self, a, b):
        """Connect ``a`` and ``b``, creating either node when missing."""
        self.add_node(a)
        self.add_node(b)
        self[a].append(b)
        self[b].append(a)

    def nodes_of(self, a):
        """Return the neighbours of node ``a``."""
        return self[a]

    def rm_loops(self):
        """Remove every edge that connects a node to itself."""
        for a in self:
            # Build the filtered list first and assign it in place: deleting
            # entries while enumerating the same list skips the element that
            # follows each deletion.  Pairing the raw entries with nodes_of()
            # keeps this correct when a subclass stores richer entries.
            self[a][:] = [
                entry
                for entry, b in zip(self[a], self.nodes_of(a))
                if a != b
            ]

    def min_path(self, a, b):
        """Return the number of edges on a shortest path from ``a`` to ``b``.

        Raises:
            KeyError: if ``a`` differs from ``b`` and is not in the graph.
            ValueError: if ``b`` cannot be reached from ``a``.
        """
        if a == b:
            return 0
        # Breadth-first search: with unit edge lengths, the distance at which
        # a node is first discovered is already its shortest distance.
        dist = {a: 0}
        queue = deque([a])
        while queue:
            node = queue.popleft()
            for c in self[node]:
                if c in dist:
                    continue
                dist[c] = dist[node] + 1
                if c == b:
                    return dist[c]
                queue.append(c)
        raise ValueError("no path from {!r} to {!r}".format(a, b))

    def min_paths(self, a):
        """Return shortest edge counts from ``a`` to every reachable node.

        Returns:
            A dict mapping each node reachable from ``a`` (``a`` itself
            excluded) to its distance in edges.  Nodes that cannot be
            reached are left out.

        Raises:
            KeyError: if ``a`` is not in the graph.
        """
        dist = {a: 0}
        queue = deque([a])
        while queue:
            node = queue.popleft()
            for c in self[node]:
                if c not in dist:
                    dist[c] = dist[node] + 1
                    queue.append(c)
        # The start node only seeded the search; it is not part of the result.
        del dist[a]
        return dist
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WeightedGraph(Graph):
    """Undirected graph whose adjacency lists hold ``(neighbour, weight)`` pairs.

    The shortest-path methods assume non-negative edge weights.
    """

    def nodes_of(self, a):
        """Return the neighbours of ``a`` without their edge weights."""
        return [b for b, _ in self[a]]

    def add_edge(self, a, b, w):
        """Connect ``a`` and ``b`` with weight ``w``, creating missing nodes."""
        # Register both endpoints first so that an edge between new nodes
        # does not fail on the adjacency lookup.
        self.add_node(a)
        self.add_node(b)
        self[a].append((b, w))
        self[b].append((a, w))

    def _by_distance(self, a):
        """Yield ``(node, distance)`` for nodes reachable from ``a``, nearest first."""
        settled = set()
        frontier = {a: 0}
        while frontier:
            # Settle the closest tentative node; its distance is now final.
            node = min(frontier, key=frontier.get)
            w = frontier.pop(node)
            settled.add(node)
            yield node, w
            for c, edge_w in self[node]:
                if c in settled:
                    continue
                candidate = w + edge_w
                if c not in frontier or candidate < frontier[c]:
                    frontier[c] = candidate

    def min_path(self, a, b):
        """Return the total weight of a lightest path from ``a`` to ``b``.

        Raises:
            KeyError: if ``a`` differs from ``b`` and is not in the graph.
            ValueError: if ``b`` cannot be reached from ``a``.
        """
        for node, w in self._by_distance(a):
            if node == b:
                return w
        raise ValueError("no path from {!r} to {!r}".format(a, b))

    def min_paths(self, a):
        """Return lightest path weights from ``a`` to every reachable node.

        Returns:
            A dict mapping each node reachable from ``a`` (``a`` itself
            excluded) to the total weight of a lightest path to it.

        Raises:
            KeyError: if ``a`` is not in the graph.
        """
        pairs = self._by_distance(a)
        # The first pair is always the start node at distance 0; drop it.
        next(pairs)
        return dict(pairs)
|