diff --git a/chess.py b/chess.py index 6cb9f36..8380936 100644 --- a/chess.py +++ b/chess.py @@ -198,9 +198,9 @@ class Game: if piece == "pawn": r = [] - dir = up + dir, back = up, down if color == "black": - dir = down + dir, back = down, up frwd = move(sq, dir) jump = move(frwd, dir) @@ -209,7 +209,12 @@ class Game: is_on_pawn_rank = pawn_ranks[color] == sq[1] if is_on_pawn_rank and is_empty(jump): r.append(jump) - return r + [sq for sq in targets if can_eat(sq)] + for t in targets: + a, b = move(t, dir), move(t, back) + en_passant = can_eat(b) and self.moves[-1] == (a,b) + if can_eat(t) or en_passant: + r.append(t) + return r elif piece == "king": return [sq for sq in targets if not self.is_attacked(color, sq)] else: @@ -265,7 +270,7 @@ def test(): game = Game() assert len(squares) == 8**2 assert sum(map(len, init_positions.values())) == 8*4 - moves = [("a2", "a4"), ("b1", "c3"), ("c3", "b5"), ("b8", "c6")] + moves = [("a2", "a4"), ("b8", "c6"), ("a4", "a5"), ("b7", "b5"), ("a5", "b6")] for m in moves: game.move(*m) print (game)