Refactored intcode.py. Added lib module.

master
Tibor Bizjak 2023-03-13 23:31:20 +01:00
parent 71fa1a9cd3
commit 4cb82e81e3
4 changed files with 195 additions and 106 deletions

View File

@ -1,4 +1,3 @@
from lib import format_main
from itertools import product from itertools import product
from operator import mul, add from operator import mul, add

View File

@ -1,5 +1,4 @@
from lib import format_main from intcode import Interpreter, makeIO, Singleton, Halted
from intcode import Interpreter, makeIO, Singleton
from itertools import permutations from itertools import permutations
stack_size = 5 stack_size = 5
@ -32,16 +31,14 @@ def partII(amp):
for perm in permutations(phase_range): for perm in permutations(phase_range):
amps = [amp() for _ in range(stack_size)] amps = [amp() for _ in range(stack_size)]
for p, a in zip(perm, amps): for p, a in zip(perm, amps):
a.write([p]) a.eval([p])
next(a)
amp_in = fst_amp_input amp_in = fst_amp_input
while True: while True:
try: try:
for a in amps: for a in amps:
a.write([amp_in]) amp_in = a.eval([amp_in])
amp_in = next(a) except Halted:
except StopIteration:
break break
best = max(amp_in, best) best = max(amp_in, best)

View File

@ -1,16 +1,13 @@
from collections import defaultdict from collections import defaultdict, deque
from enum import Enum from enum import Enum
from inspect import signature from inspect import signature
import operator import operator
class IntcodeError(Exception): class WaitForInput(Exception):
pass pass
class OpState(Enum): class Halted(Exception):
"""Execution state class.""" pass
CONTINUE = 0
HALT = 1
WAIT = 2
class defaultlist(list): class defaultlist(list):
"""Default list class. Allows writing and reading out of bounds.""" """Default list class. Allows writing and reading out of bounds."""
@ -28,6 +25,89 @@ class defaultlist(list):
return self.val_factory() return self.val_factory()
return super().__getitem__(i) return super().__getitem__(i)
class Emulator(object):
def __init__(self, program, get_input, put_output):
self.in_f = get_input
self.out_f = put_output
self.memory = defaultlist(program[:], val_factory = lambda : 0)
self.i = 0
self.rel_base = 0
def _of_operator(operator):
"""Make opcode from operator."""
def opcode(self, p1, p2, p3):
r = operator(self.memory[p1], self.memory[p2])
self.memory[p3] = r
self.i += 4
return opcode
def _jump_when(flag):
def opcode(self, p1, p2):
if (self.memory[p1] > 0) == flag:
self.i = self.memory[p2]
else:
self.i += 3
return opcode
def _get_input(self, p1):
x = self.in_f()
self.memory[p1] = x
self.i += 2
def _put_output(self, p1):
self.i += 2
self.out_f(self.memory[p1])
def _adjust_base(self, p1):
self.rel_base += self.memory[p1]
self.i += 2
opcodes = {1 : _of_operator(operator.add),
2 : _of_operator(operator.mul),
3 : _get_input,
4 : _put_output,
5 : _jump_when(True),
6 : _jump_when(False),
7 : _of_operator(operator.lt),
8 : _of_operator(operator.eq),
9 : _adjust_base,
}
def __next__(self):
state = self
op = str(self.memory[state.i])
if op == '99':
raise StopIteration
par_modes, op = op[:-2][::-1], int(op[-2:])
opcode = self.opcodes[op]
parnum = len(signature(opcode).parameters) - 1
par_modes = par_modes + '0'*(parnum - len(par_modes))
par_modes = map(int, par_modes)
pars = []
for pn, mode in enumerate(par_modes, start=1):
p = state.i + pn
if mode == 0:
p = state.memory[p]
elif mode == 2:
p = state.rel_base + state.memory[p]
pars.append(p)
opcode(state, *pars)
def __iter__(self):
return self
def run(self):
if self.memory[self.i] == 99:
raise Halted
return deque(self, maxlen=0)
class Singleton(object): class Singleton(object):
def __init__(self, x=None): def __init__(self, x=None):
self.x = x self.x = x
@ -64,7 +144,7 @@ def makeIO(in_buff_class, out_buff_class):
def pop_input(self): def pop_input(self):
if self.in_buff == in_buff_class(): if self.in_buff == in_buff_class():
return None raise WaitForInput
return self.in_buff.pop(0) return self.in_buff.pop(0)
def append_output(self, x): def append_output(self, x):
@ -94,98 +174,21 @@ def makeIO(in_buff_class, out_buff_class):
StackIO = makeIO(list, list) StackIO = makeIO(list, list)
SingletonIO = makeIO(Singleton, Singleton) SingletonIO = makeIO(Singleton, Singleton)
class Opcode:
def of_operator(operator):
"""Make opcode from operator."""
def opcode(state, p1, p2, p3):
r = operator(state.memory[p1], state.memory[p2])
state.memory[p3] = r
state.i += 4
return OpState.CONTINUE
return opcode
def jump_when(flag):
def opcode(state, p1, p2):
if (state.memory[p1] > 0) == flag:
state.i = state.memory[p2]
else:
state.i += 3
return OpState.CONTINUE
return opcode
def adjust_base(state, p1):
state.rel_base += state.memory[p1]
state.i += 2
return OpState.CONTINUE
def get_input(state, p1):
x = state.IO.pop_input()
if x == None:
return OpState.WAIT
state.memory[p1] = x
state.i += 2
return OpState.CONTINUE
def put_output(state, p1):
state.IO.append_output(state.memory[p1])
state.i += 2
return OpState.CONTINUE
halt = lambda _ : OpState.HALT
ops = {1 : of_operator(operator.add),
2 : of_operator(operator.mul),
3 : get_input,
4 : put_output,
5 : jump_when(True),
6 : jump_when(False),
7 : of_operator(operator.lt),
8 : of_operator(operator.eq),
9 : adjust_base,
99 : halt
}
def parse(op):
op = str(op)
par_modes, op = op[:-2][::-1], int(op[-2:])
opcode = Opcode.ops[op]
parnum = len(signature(opcode).parameters) - 1
par_modes = par_modes + '0'*(parnum - len(par_modes))
return opcode, map(int, par_modes)
def run(state):
opcode, par_modes = Opcode.parse(state.memory[state.i])
pars = []
for pn, mode in enumerate(par_modes, start=1):
p = state.i + pn
if mode == 0:
p = state.memory[p]
elif mode == 2:
p = state.rel_base + state.memory[p]
pars.append(p)
op_state = opcode(state, *pars)
return op_state
class Interpreter(object): class Interpreter(object):
def __init__(self, program, IO_class=StackIO, i=0, rel_base=0): def __init__(self, program, IO_class=StackIO):
self.memory = defaultlist(program[:], val_factory = lambda : 0)
self.IO = IO_class() self.IO = IO_class()
self.i = i self.comp = Emulator(program, self.IO.pop_input, self.IO.append_output)
self.rel_base = rel_base
def __iter__(self): def __iter__(self):
return self while True:
try:
def __next__(self): self.comp.run()
if Opcode.run(self) == OpState.HALT: break
raise StopIteration except WaitForInput:
while Opcode.run(self) == OpState.CONTINUE: yield self.IO.flush()
continue yield self.IO.flush()
return self.IO.flush()
def write(self, in_buff): def write(self, in_buff):
self.IO.write(in_buff) self.IO.write(in_buff)
@ -193,10 +196,17 @@ class Interpreter(object):
def run(self, in_buff=None): def run(self, in_buff=None):
if in_buff != None: if in_buff != None:
self.write(in_buff) self.write(in_buff)
out = next(self) self.comp.run()
if Opcode.run(self) != OpState.HALT: return self.IO.flush()
raise IntcodeError("expecting input")
return out def eval(self, in_buff):
self.write(in_buff)
try:
self.comp.run()
except WaitForInput:
pass
return self.IO.flush()
def copy(self): def copy(self):
memory = self.memory.copy() memory = self.memory.copy()

83
lib.py 100644
View File

@ -0,0 +1,83 @@
class Graph(dict):
nodes = dict.keys
def add_node(self, a):
if a not in self:
self[a] = []
def add_edge(self, a, b):
self.add_node(a)
self.add_node(b)
self[a].append(b)
self[b].append(a)
def nodes_of(self, a):
return self[a]
def rm_loops(self):
for a in self.keys():
for i, b in enumerate(self.nodes_of(a)):
if a == b:
del self[a][i]
def min_path(self, a, b):
weights = dict()
visited = set()
node, w = a, 0
while node != b:
for c in self[node]:
if c in visited:
continue
if c in weights and weights[c] <= w + 1:
continue
weights[c] = w + 1
node, w = min(weights.items() , key=lambda x: x[1])
del weights[node]
visited.add(node)
return w
def min_paths(self, a):
weights = dict()
visited = set()
visited.add(a)
node, w = a, 0
while len(visited) != len(self):
for c in self[node]:
if c in visited:
continue
if c in weights and weights[c] <= w + 1:
continue
weights[c] = w + 1
node, w = min(filter(lambda x: x[0] not in visited, weights.items()) , key=lambda x: x[1])
visited.add(node)
return weights
class WeightedGraph(Graph):
def nodes_of(self, a):
return [b for b, _ in self[a]]
def add_edge(self, a, b, w):
self[a].append((b, w))
self[b].append((a, w))
def min_path(self, a, b):
weights = dict()
visited = set()
node, w = a, 0
while node != b:
for c, edge_w in self[node]:
if c in visited:
continue
if c in weights and weights[c] <= w + edge_w:
continue
weights[c] = w + edge_w
node, w = min(weights.items() , key=lambda x: x[1])
del weights[node]
visited.add(node)
return w