"""
An environment inspired by the game of Tetris.
"""
import itertools
import re
import warnings
from typing import List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
from matplotlib.axes import Axes
from torchtyping import TensorType
from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import set_device, tint
[docs]
# Catalogue of the available pieces, keyed by their letter. Each value is a
# two-element list: [base index of the piece, matrix with the shape of the piece
# before any rotation]. Non-zero cells of the matrix hold the base index of the
# piece and zeros are empty cells.
PIECES = {
    "I": [1, [[1], [1], [1], [1]]],
    "J": [2, [[0, 2], [0, 2], [2, 2]]],
    "L": [3, [[3, 0], [3, 0], [3, 3]]],
    "O": [4, [[4, 4], [4, 4]]],
    "S": [5, [[0, 5, 5], [5, 5, 0]]],
    "T": [6, [[6, 6, 6], [0, 6, 0]]],
    "Z": [7, [[7, 7, 0], [0, 7, 7]]],
}
[docs]
# Colour of a board cell as three 0-255 channel values (presumably RGB, since the
# values are written into a uint8 image shown with imshow), keyed by the base
# index of the piece that occupies the cell. Key 0 is the colour of empty cells.
PIECES_COLORS = {
    0: [255, 255, 255],
    1: [19, 232, 232],
    2: [30, 30, 201],
    3: [240, 110, 2],
    4: [236, 236, 14],
    5: [0, 128, 0],
    6: [125, 5, 126],
    7: [236, 14, 14],
}
[docs]
class Tetris(GFlowNetEnv):
"""
Tetris environment: an environment inspired by the game of tetris. It's not
supposed to be a game, but rather a toy environment with an intuitive state and
action space.
The state space is a 2D board, with all the combinations of pieces on it. Pieces that
are added to the board are identified by a number that starts from
piece_idx * max_pieces_per_type, and is incremented by 1 with each new piece from
the same type. This number fills in the cells of the board where the piece is
located. This enables telling apart pieces of the same type.
The action space is the choice of piece, its rotation and horizontal location
where to drop the piece. The action space may be constrained according to needs.
Attributes
----------
width : int
Width of the board.
height : int
Height of the board.
pieces : list
Pieces to use, identified by [I, J, L, O, S, T, Z]
rotations : list
Valid rotations, from [0, 90, 180, 270]
"""
def __init__(
    self,
    width: int = 10,
    height: int = 20,
    pieces: List = ["I", "J", "L", "O", "S", "T", "Z"],
    rotations: List = [0, 90, 180, 270],
    allow_redundant_rotations: bool = False,
    allow_eos_before_full: bool = False,
    **kwargs,
):
    """
    Initializes the Tetris environment.

    Args
    ----
    width : int
        Width of the board (number of columns).
    height : int
        Height of the board (number of rows).
    pieces : list
        Letters of the pieces to use, a subset of [I, J, L, O, S, T, Z].
    rotations : list
        Allowed rotations in degrees, a subset of [0, 90, 180, 270].
    allow_redundant_rotations : bool
        If False, rotations of a piece that yield a matrix identical to an
        already included one are left out of the action space.
    allow_eos_before_full : bool
        If True, the end-of-sequence action is valid even if pieces can still
        be dropped onto the board.
    **kwargs
        Forwarded to the base class initializer. It must contain the key
        "device".
    """
    assert all(
        [p in ["I", "J", "L", "O", "S", "T", "Z"] for p in pieces]
    ), f"Unknown piece in {pieces}"
    assert all(
        [r in [0, 90, 180, 270] for r in rotations]
    ), f"Unsupported rotation in {rotations}"
    self.device = set_device(kwargs["device"])
    # Integer type of the board. It must be set here, before the base class
    # init, because piece2mat() is used below. int16 matches the dtype used by
    # set_state() when converting non-tensor states.
    self.int = torch.int16
    self.width = width
    self.height = height
    # The (mutable) default lists are copied so that they are never shared
    # between instances.
    self.pieces = list(pieces)
    self.rotations = list(rotations)
    self.allow_redundant_rotations = allow_redundant_rotations
    self.allow_eos_before_full = allow_eos_before_full
    self.max_pieces_per_type = 100
    # Helper functions and dicts
    self.piece2idx = lambda letter: PIECES[letter][0]
    self.idx2piece = {v[0]: k for k, v in PIECES.items()}
    self.piece2mat = lambda letter: tint(
        PIECES[letter][1], int_type=self.int, device=self.device
    )
    self.rot2idx = {0: 0, 90: 1, 180: 2, 270: 3}
    # Precompute all possible rotations of each piece and the corresponding
    # binary mask. In the same pass, check that every rotated piece fits in the
    # board: heights are compared with the board height and widths with the
    # board width.
    self.piece_rotation_mat = {}
    self.piece_rotation_mask_mat = {}
    for piece in self.pieces:
        self.piece_rotation_mat[piece] = {}
        self.piece_rotation_mask_mat[piece] = {}
        for rotation in self.rotations:
            piece_mat = torch.rot90(self.piece2mat(piece), k=self.rot2idx[rotation])
            hp, wp = piece_mat.shape
            assert (
                self.height >= hp
            ), f"Piece {piece} rotated {rotation} is taller than the board"
            assert (
                self.width >= wp
            ), f"Piece {piece} rotated {rotation} is wider than the board"
            self.piece_rotation_mat[piece][rotation] = piece_mat
            self.piece_rotation_mask_mat[piece][rotation] = piece_mat != 0
    # Source state: empty board
    self.source = torch.zeros(
        (self.height, self.width), dtype=self.int, device=self.device
    )
    # End-of-sequence action: all -1
    self.eos = (-1, -1, -1)
    # Base class init
    super().__init__(**kwargs)
[docs]
def get_action_space(self):
    """
    Builds the list of all actions of the environment, with EOS as last element.

    An action is a tuple (piece, rotation, col): the base index of the piece, the
    rotation in degrees and the board column where the left-most part of the
    rotated piece is dropped. Unless allow_redundant_rotations is True, a rotation
    whose matrix is identical to one already included is skipped.
    """
    action_space = []
    known_mats = []
    for letter in self.pieces:
        for degrees in self.rotations:
            rotated = torch.rot90(self.piece2mat(letter), k=self.rot2idx[degrees])
            # Skip rotations that reproduce an already registered matrix
            if not self.allow_redundant_rotations and any(
                torch.equal(known, rotated) for known in known_mats
            ):
                continue
            known_mats.append(rotated)
            # Only the columns where the piece fits entirely within the board
            n_fitting_cols = self.width - rotated.shape[1] + 1
            action_space.extend(
                (self.piece2idx(letter), degrees, col)
                for col in range(n_fitting_cols)
            )
    action_space.append(self.eos)
    return action_space
def _drop_piece_on_board(
    self, action, state: Optional[TensorType["height", "width"]] = None
):
    """
    Drops a piece defined by the argument action onto the board. It returns an
    updated board (copied) and a boolean variable, which is True if the piece can
    be dropped onto the current board and False otherwise.

    Args
    ----
    action : tuple
        Tuple (piece_idx, rotation, col): base index of the piece, rotation in
        degrees and column of the left-most part of the piece.
    state : tensor
        Board onto which the piece is dropped. It is resolved through
        self._get_state(state, do_copy=True).

    Returns
    -------
    board : tensor
        The board with the piece added if the action is valid, or the board
        without changes otherwise.
    valid : bool
        True if the piece could be placed, False otherwise.
    """
    board = self._get_state(state, do_copy=True)
    piece_idx, rotation, col = action
    # The precomputed matrix and mask are cloned because the matrix is modified
    # below (its cells are overwritten with the index of the new piece).
    piece_mat = self.piece_rotation_mat[self.idx2piece[piece_idx]][rotation].clone()
    piece_mat_mask = self.piece_rotation_mask_mat[self.idx2piece[piece_idx]][
        rotation
    ].clone()
    hp, wp = piece_mat.shape
    # Check if piece goes overboard horizontally
    if col + wp > self.width:
        return board, False
    # Find the highest row occupied by any other piece in the columns where we wish
    # to add the new piece
    occupied_rows = board[:, col : col + wp].sum(1).nonzero()
    if len(occupied_rows) == 0:
        # All rows are unoccupied. Set highest occupied row as the row "below" the
        # board.
        highest_occupied_row = self.height
    else:
        highest_occupied_row = occupied_rows[0, 0]
    # Iteratively attempt to place piece on the board starting from the top.
    # As soon as we reach a row where we can't place the piece because it
    # creates a collision, we can stop the search (since we can't put a piece under
    # this obstacle anyway).
    # The search starts at the row where the bottom of the piece lies right above
    # the highest occupied row, which may be a negative row (above the board).
    starting_row = highest_occupied_row - hp
    lowest_valid_row = None
    for row in range(starting_row, self.height - hp + 1):
        if row == -hp:
            # Placing the piece here would make it land fully outside the board.
            # This means that there is no place on the board for the piece
            break
        elif row < 0:
            # It is not possible to place the piece at this row because the piece
            # would not completely be in the board. However, it is still possible
            # to check for obstacles because any obstacle will prevent placing the
            # piece at any position below
            # Only the bottom (row + hp) rows of the piece overlap with the board.
            board_section = board[: row + hp, col : col + wp]
            piece_mask_section = piece_mat_mask[-(row + hp) :]
            if (board_section * (piece_mask_section != 0)).any():
                # An obstacle has been found.
                break
        else:
            # The piece can be placed here if all board cells under piece are empty
            board_section = board[row : row + hp, col : col + wp]
            if (board_section * piece_mat_mask).any():
                # The piece cannot be placed here and cannot be placed any lower
                # because of an obstacle.
                break
            else:
                # The piece can be placed here.
                lowest_valid_row = row
    # Place the piece if possible
    if lowest_valid_row is None:
        # The piece cannot be placed. Return the board as-is.
        return board, False
    else:
        # Get and set index of new piece
        piece_idx = self._get_max_piece_idx(board, piece_idx, incr=1)
        piece_mat[piece_mat_mask] = piece_idx
        # Place the piece on the board.
        # The cells under the mask are empty, so adding the matrix writes the index
        # of the piece there and leaves the rest of the section untouched.
        row = lowest_valid_row
        board[row : row + hp, col : col + wp] += piece_mat
        return board, True
[docs]
def get_mask_invalid_actions_forward(
    self,
    state: Optional[List] = None,
    done: Optional[bool] = None,
) -> List:
    """
    Computes the mask of invalid forward actions from a state.

    The returned list has one entry per action: True if the action is invalid
    from the state, False otherwise. If done is True, all actions are invalid.
    A piece action is invalid if the piece cannot be dropped onto the board. The
    EOS action (last entry) is invalid if allow_eos_before_full is False and at
    least one piece action is still valid.
    """
    state = self._get_state(state)
    done = self._get_done(done)
    if done:
        return [True] * self.policy_output_dim
    mask = [False] * self.policy_output_dim
    # Check every action but EOS by attempting to drop the piece
    for position, candidate in enumerate(self.action_space[:-1]):
        _, can_drop = self._drop_piece_on_board(candidate, state)
        if not can_drop:
            mask[position] = True
    # EOS is only available before the board is full if explicitly allowed
    no_piece_fits = all(mask[:-1])
    if not (self.allow_eos_before_full or no_piece_fits):
        mask[-1] = True
    return mask
[docs]
def states2proxy(
    self,
    states: Union[
        List[TensorType["height", "width"]], TensorType["height", "width", "batch"]
    ],
) -> TensorType["height", "width", "batch"]:
    """
    Prepares a batch of states in "environment format" for a proxy: simply
    converts non-zero (non-empty) cells into 1s.

    The input is not modified: the result is a new tensor.

    Args
    ----
    states : list of 2D tensors or 3D tensor
        A batch of states in environment format, either as a list of states or as a
        single tensor.

    Returns
    -------
    A tensor containing all the states in the batch, with the integer type of the
    environment (self.int).
    """
    states = tint(states, device=self.device, int_type=self.int)
    # Build a new tensor instead of writing into `states`: the conversion above
    # may hand back the very tensor passed by the caller (presumably when it
    # already has the requested dtype and device - TODO confirm against tint()),
    # in which case an in-place assignment would corrupt the caller's states.
    return (states != 0).to(self.int)
[docs]
def states2policy(
    self,
    states: Union[
        List[TensorType["height", "width"]], TensorType["height", "width", "batch"]
    ],
) -> TensorType["height", "width", "batch"]:
    """
    Prepares a batch of states in "environment format" for the policy model.

    The states are first converted as in states2proxy(). Then all dimensions from
    dimension 1 onwards are flattened and the result is cast to the float type of
    the environment (self.float).

    Args
    ----
    states : list of 2D tensors or 3D tensor
        A batch of states in environment format, either as a list of states or as a
        single tensor.

    Returns
    -------
    A tensor containing all the states in the batch.
    """
    boards = tint(states, device=self.device, int_type=self.int)
    occupancy = self.states2proxy(boards)
    flattened = occupancy.flatten(start_dim=1)
    return flattened.to(self.float)
[docs]
def state2readable(self, state: Optional[TensorType["height", "width"]] = None):
    """
    Converts a state (board) into a human-friendly string.

    The board is printed as a numpy array, with one row per line and without the
    outer brackets nor the indentation of the rows.
    """
    state = self._get_state(state)
    if isinstance(state, tuple):
        board = np.stack(state)
    elif isinstance(state, list):
        board = np.array(state)
    else:
        board = state.cpu().numpy()
    text = str(board)
    # Drop the outer brackets and the leading space that numpy adds to each row
    for old, new in (("[[", "["), ("]]", "]"), ("\n ", "\n")):
        text = text.replace(old, new)
    return text
[docs]
def readable2state(self, readable, alphabet={}):
    """
    Converts a human-readable string representing a state into a state: a 2D
    tensor with one row per line of the string.

    Each line is expected to contain the integers of a board row separated by
    white space and optionally enclosed in square brackets. The argument alphabet
    is not used.
    """
    whitespace = re.compile(r"\s+")
    board_rows = []
    for line in readable.split("\n"):
        # Normalize white space and remove the padding next to the brackets
        line = whitespace.sub(" ", line).replace(" ]", "]").replace("[ ", "[")
        cells = [int(token) for token in line.strip("[]").split(" ")]
        board_rows.append(tint(cells, int_type=self.int, device=self.device))
    return torch.stack(board_rows)
[docs]
def get_parents(
    self,
    state: Optional[List] = None,
    done: Optional[bool] = None,
    action: Optional[Tuple] = None,
) -> Tuple[List, List]:
    """
    Determines all parents and actions that lead to state.

    If done is True, the only parent is the state itself, with the EOS action.
    Otherwise, there is one parent for each piece on the board that can be lifted
    (see _piece_can_be_lifted()): the board without that piece, paired with the
    action (piece, rotation, col) obtained from _get_idx_rotation_col().

    Args
    ----
    state : tensor
        The board whose parents are computed. If None, it is resolved by
        self._get_state().
    done : bool
        Whether the trajectory is done. If None, done is taken from instance.
    action : None
        Ignored

    Returns
    -------
    parents : list
        List of parents in state format
    actions : list
        List of actions that lead to state for each parent in parents
    """
    state = self._get_state(state)
    done = self._get_done(done)
    if done:
        return [state], [self.eos]
    parents, actions = [], []
    indices_on_board = state.unique()
    for piece_id in indices_on_board[indices_on_board > 0]:
        if not self._piece_can_be_lifted(state, piece_id):
            continue
        base_idx, rotation, col = self._get_idx_rotation_col(state, piece_id)
        # The parent is a copy of the board with the cells of the piece emptied
        parents.append(state.masked_fill(state == piece_id, 0))
        actions.append((base_idx, rotation, col))
    return parents, actions
[docs]
def step(
    self, action: Tuple[int], skip_mask_check: bool = False
) -> Tuple[List[int], Tuple[int], bool]:
    """
    Executes step given an action.

    Args
    ----
    action : tuple
        Action to be executed: either the EOS action or a tuple
        (piece, rotation, col) indicating the piece to drop onto the board.
    skip_mask_check : bool
        If True, skip computing forward mask of invalid actions to check if the
        action is valid.

    Returns
    -------
    self.state : tensor
        The state (board) after executing the action
    action : tuple
        Action executed
    valid : bool
        False, if the action is not allowed for the current state.
    """
    # Generic pre-step checks
    skip_check = skip_mask_check or self.skip_mask_check
    do_step, self.state, action = self._pre_step(action, skip_check)
    if not do_step:
        return self.state, action, False
    # EOS: the trajectory ends and the state does not change
    if action == self.eos:
        self.done = True
        self.n_actions += 1
        return self.state, self.eos, True
    # Any other action: attempt to drop the piece onto the board
    board_after, placed = self._drop_piece_on_board(action)
    if placed:
        self.state = board_after
        self.n_actions += 1
    return self.state, action, placed
def _get_max_trajectory_length(self) -> int:
    """
    Returns the maximum trajectory length of the environment, including the EOS
    action: the number of board cells divided by 4 (integer division), plus one
    for EOS.
    """
    n_cells = self.width * self.height
    max_pieces = n_cells // 4
    return max_pieces + 1
[docs]
def set_state(
    self, state: TensorType["height", "width"], done: Optional[bool] = False
):
    """
    Sets the state and done. If done is True but incompatible with state (done is
    True, allow_eos_before_full is False and state is not full), then force done
    False and print warning. Also, make sure state is tensor.

    Args
    ----
    state : tensor or array-like
        The board to set as state. If it is not a tensor, it is converted into
        one with the integer type and device of the environment.
    done : bool
        Whether the trajectory is done.

    Returns
    -------
    The output of the base class set_state().
    """
    if not torch.is_tensor(state):
        # Use the integer type and device of the environment, like the source
        # state, instead of a hard-coded dtype on the default device.
        state = torch.tensor(state, dtype=self.int, device=self.device)
    if done is True and not self.allow_eos_before_full:
        # done can only be True if no piece can be dropped onto the board, that
        # is if all actions but EOS are invalid.
        mask = self.get_mask_invalid_actions_forward(state, done=False)
        if not all(mask[:-1]):
            done = False
            warnings.warn(
                f"Attempted to set state\n\n{self.state2readable(state)}\n\n"
                "with done = True, which is not compatible with "
                "allow_eos_before_full = False. Forcing done = False."
            )
    return super().set_state(state, done)
def _piece_can_be_lifted(self, board, piece_idx):
    """
    Checks whether the piece with index piece_idx could be lifted from the board.

    A piece can be lifted if every cell above each of its cells is either empty or
    part of the piece itself. If piece_idx is smaller than max_pieces_per_type, it
    is taken as a base index and replaced by the output of _get_max_piece_idx()
    with incr=0.

    Returns True if the piece can be lifted, False otherwise.
    """
    if piece_idx < self.max_pieces_per_type:
        piece_idx = self._get_max_piece_idx(board, piece_idx, incr=0)
    piece_rows, piece_cols = torch.where(board == piece_idx)
    # Gather, for each cell of the piece, the part of its column that lies above
    cells_above = torch.cat(
        [board[:row, col] for row, col in zip(piece_rows, piece_cols)]
    )
    # Cells of the piece itself do not block the way up
    blocking = (cells_above != 0) & (cells_above != piece_idx)
    return not bool(blocking.any())
def _get_idx_rotation_col(self, board, piece_idx):
    """
    Recovers the action (piece, rotation, col) that corresponds to a piece that
    is on the board.

    Args
    ----
    board : tensor
        The board matrix. It is not modified.
    piece_idx : int
        Index of the piece as stored in the cells of the board, that is
        base index * max_pieces_per_type + order of the piece.

    Returns
    -------
    piece_idx_base : int
        Base index of the piece.
    rotation : int
        The first rotation in self.rotations whose matrix matches the cells
        occupied by the piece.
    col : int
        Left-most column occupied by the piece.

    Raises
    ------
    ValueError
        If the piece is not on the board or no rotation matches its cells.
    """
    piece_idx_base = int(piece_idx / self.max_pieces_per_type)
    piece_mat = self.piece2mat(self.idx2piece[piece_idx_base])
    rows, cols = torch.where(board == piece_idx)
    if rows.numel() == 0:
        raise ValueError(f"Piece {piece_idx} is not on board {board}")
    # Bounding box of the piece
    row = rows.min().item()
    col = cols.min().item()
    hp = rows.max().item() - row + 1
    wp = cols.max().item() - col + 1
    # Matrix of the bounding box with the base index in the cells of the piece
    # and zeros elsewhere (cells of other pieces inside the box are discarded).
    # It is a new tensor with the dtype of the board, so the board is untouched.
    board_section = board[row : row + hp, col : col + wp]
    piece_on_board = (board_section == piece_idx).to(board.dtype) * piece_idx_base
    for rotation in self.rotations:
        piece_mat_rot = torch.rot90(piece_mat, k=self.rot2idx[rotation])
        if piece_mat_rot.shape == piece_on_board.shape and torch.equal(
            piece_mat_rot, piece_on_board
        ):
            return piece_idx_base, rotation, col
    raise ValueError(
        f"No valid rotation found for piece {piece_idx} in board {board}"
    )
def _get_max_piece_idx(
    self, board: TensorType["height", "width"], piece_idx: int, incr: int = 0
):
    """
    Gets the index of a piece with base index piece_idx, based on the board.

    The indices of the pieces of a type start at
    piece_idx * max_pieces_per_type. If the largest index on the board below the
    first index of the next type is at least that starting index, it is returned
    increased by incr. Otherwise, the starting index is returned regardless of
    incr.

    board : tensor
        The current board matrix.
    piece_idx : int
        Piece index, in base format [1, 2, ...]
    incr : int
        Increment of the returned index with respect to the max.
    """
    first_idx = piece_idx * self.max_pieces_per_type
    next_type_first_idx = first_idx + self.max_pieces_per_type
    # Discard the cells that belong to pieces of types with larger base index
    candidates = board.masked_fill(board >= next_type_first_idx, 0)
    largest_idx = candidates.max()
    if largest_idx < first_idx:
        return first_idx
    return largest_idx + incr
[docs]
def plot_samples_topk(
    self,
    samples: List,
    rewards: TensorType["batch_size"],
    k_top: int = 10,
    n_rows: int = 2,
    dpi: int = 150,
    **kwargs,
):
    """
    Plot tetris boards of top K samples.

    Parameters
    ----------
    samples : list
        List of terminating states sampled from the policy.
    rewards : tensor or list
        Rewards of the samples.
    k_top : int
        The number of samples that will be included in the plot. The k_top samples
        with the highest reward are selected. If there are fewer samples than
        k_top, all of them are plotted.
    n_rows : int
        Number of rows in the plot. The number of columns will be calculated
        according the n_rows and k_top.
    dpi : int
        DPI (dots per inch) of the figure, to determine the resolution.

    Returns
    -------
    fig : matplotlib Figure
        The figure with one board per axes.
    """
    # Init figure. squeeze=False makes axes a 2D array even for a 1x1 grid.
    n_cols = int(np.ceil(k_top / n_rows))
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, dpi=dpi, squeeze=False)
    axes = axes.flatten()
    # Select top-k samples and plot them. torch.sort returns the pair
    # (values, indices), so the slice must be applied to the indices and not to
    # the output of sort.
    _, indices_sorted = torch.sort(torch.as_tensor(rewards), descending=True)
    indices_topk = indices_sorted[:k_top].tolist()
    for idx, ax in zip(indices_topk, axes):
        self._plot_board(samples[idx], ax)
    # Hide the axes of the grid that are left without a board
    for ax in axes[len(indices_topk) :]:
        ax.set_axis_off()
    fig.tight_layout()
    return fig
@staticmethod
def _plot_board(
    board,
    ax: Axes,
    cellsize: int = 20,
    linewidth: int = 2,
    max_pieces_per_type: int = 100,
):
    """
    Plots a single Tetris board (a state).

    Each cell is drawn as a square with the color of the piece that occupies it,
    over a gray background that is visible as the separation between cells.

    Parameters
    ----------
    board : tensor
        State to plot.
    ax : matplotlib Axes object
        A matplotlib Axes object on which the board will be plotted.
    cellsize : int
        The size (length) of each board cell, in pixels.
    linewidth : int
        The width of the separation between cells, in pixels.
    max_pieces_per_type : int
        Divisor used to recover the base index of a piece (the key of
        PIECES_COLORS) from the value of a cell. It should match the attribute
        max_pieces_per_type of the environment that generated the board.
    """
    # Move the board to the CPU before converting it to numpy, since the
    # conversion is not possible for tensors on other devices. The board is only
    # read, so no copy is needed.
    board = board.detach().cpu().numpy()
    height = board.shape[0] * cellsize
    width = board.shape[1] * cellsize
    # Gray canvas; the extra linewidth pixels close the grid on the bottom and
    # right sides.
    board_img = 128 * np.ones(
        (height + linewidth, width + linewidth, 3), dtype=np.uint8
    )
    for row in range(board.shape[0]):
        for col in range(board.shape[1]):
            row_init = row * cellsize + linewidth
            row_end = row_init + cellsize - linewidth
            col_init = col * cellsize + linewidth
            col_end = col_init + cellsize - linewidth
            color_key = int(board[row, col] / max_pieces_per_type)
            board_img[row_init:row_end, col_init:col_end, :] = PIECES_COLORS[
                color_key
            ]
    ax.imshow(board_img)
    ax.set_axis_off()