gflownet.envs.tetris
An environment inspired by the game of Tetris.
Classes
Tetris: Tetris environment, an environment inspired by the game of Tetris.
Module Contents
- class gflownet.envs.tetris.Tetris(width=10, height=20, pieces=['I', 'J', 'L', 'O', 'S', 'T', 'Z'], rotations=[0, 90, 180, 270], allow_redundant_rotations=False, allow_eos_before_full=False, **kwargs)[source]
Bases: gflownet.envs.base.GFlowNetEnv
Tetris environment: an environment inspired by the game of Tetris. It is not supposed to be a game, but rather a toy environment with an intuitive state and action space.
The state space is a 2D board with all possible combinations of pieces on it. Each piece added to the board is identified by a number that starts from piece_idx * max_pieces_per_type and is incremented by 1 with each new piece of the same type. This number fills the cells of the board that the piece occupies, which makes it possible to tell apart pieces of the same type.
An action is the choice of a piece, its rotation and the horizontal location at which to drop it. The action space may be constrained as needed.
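The numbering scheme can be illustrated with a minimal sketch. The index mapping, the value of max_pieces_per_type and the helper names below are assumptions chosen for illustration; they are not taken from the library.

```python
from collections import defaultdict

# Hypothetical values for illustration only. Indices start at 1 here so that
# 0 can stand for an empty cell.
PIECE_IDX = {"I": 1, "J": 2, "L": 3, "O": 4, "S": 5, "T": 6, "Z": 7}
MAX_PIECES_PER_TYPE = 100


class PieceNumbering:
    """Hands out a distinct board number to every piece that is placed."""

    def __init__(self):
        # Number of pieces of each type placed so far.
        self.count = defaultdict(int)

    def next_id(self, piece):
        # First piece of a type gets piece_idx * max_pieces_per_type,
        # each further piece of the same type gets the next integer.
        piece_id = PIECE_IDX[piece] * MAX_PIECES_PER_TYPE + self.count[piece]
        self.count[piece] += 1
        return piece_id


def piece_type(piece_id):
    """Recovers the piece type from a board number."""
    idx = piece_id // MAX_PIECES_PER_TYPE
    return next(name for name, i in PIECE_IDX.items() if i == idx)


numbering = PieceNumbering()
ids = [numbering.next_id(p) for p in ["I", "O", "I"]]
```

With these assumed values, the two I pieces receive different numbers, and integer division by max_pieces_per_type recovers the type of any piece on the board.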
- Parameters:
width (int)
height (int)
pieces (List)
rotations (List)
allow_redundant_rotations (bool)
allow_eos_before_full (bool)
- get_action_space()[source]
Constructs a list with all possible actions, including eos. An action is represented by a tuple of length 3 (piece, rotation, col). The piece is represented by its index, the rotation by the integer rotation in degrees, and the location by the horizontal cell of the board occupied by the left-most part of the piece.
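As a rough sketch of how such a list could be enumerated: the function below is hypothetical, and both the EOS encoding (-1, -1, -1) and the rule that a piece must fit within the board width are assumptions, not confirmed behaviour of get_action_space().

```python
from itertools import product


def build_action_space(width, piece_indices, rotations, piece_widths,
                       eos=(-1, -1, -1)):
    """Enumerates (piece, rotation, col) tuples and appends an EOS action.

    piece_widths maps (piece, rotation) to the width in cells of the rotated
    piece. It is a hypothetical helper structure for this sketch.
    """
    actions = []
    for piece, rotation in product(piece_indices, rotations):
        piece_width = piece_widths[(piece, rotation)]
        # col is the left-most column of the piece, so only columns where
        # the whole piece stays inside the board are listed.
        for col in range(width - piece_width + 1):
            actions.append((piece, rotation, col))
    actions.append(eos)
    return actions


# A board of width 4 with only the O piece (index 4, two cells wide).
actions = build_action_space(
    width=4, piece_indices=[4], rotations=[0], piece_widths={(4, 0): 2}
)
```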
- get_mask_invalid_actions_forward(state=None, done=None)[source]
- Returns a list with one entry per action in the action space, with values:
True if the forward action is invalid from the current state.
False otherwise.
- Parameters:
state (Optional[List])
done (Optional[bool])
- Return type:
List
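The following sketch shows one way such a mask could be computed on a board stored as nested lists, with row 0 at the top. It assumes that a piece action is valid when the piece can be placed at the top of the board without overlapping occupied cells, and that EOS is only allowed once no piece fits unless allow_eos_before_full is set. The helper names, the EOS encoding and these rules are assumptions, not the library's implementation.

```python
EOS = (-1, -1, -1)  # assumed EOS encoding, for illustration only


def fits_at_top(board, piece, col):
    """True if `piece` can enter the board at column `col` without overlap."""
    height, width = len(board), len(board[0])
    p_h, p_w = len(piece), len(piece[0])
    if col < 0 or col + p_w > width or p_h > height:
        return False
    return not any(
        piece[r][c] and board[r][col + c]
        for r in range(p_h)
        for c in range(p_w)
    )


def mask_invalid_forward(board, actions, shapes, done=False,
                         allow_eos_before_full=False):
    """Returns one bool per action: True means the action is invalid."""
    if done:
        # No action is valid from a finished trajectory.
        return [True] * len(actions)
    valid = {
        a: fits_at_top(board, shapes[(a[0], a[1])], a[2])
        for a in actions
        if a != EOS
    }
    any_piece_fits = any(valid.values())
    mask = []
    for a in actions:
        if a == EOS:
            # Assumption: EOS is invalid while some piece can still be placed.
            mask.append(any_piece_fits and not allow_eos_before_full)
        else:
            mask.append(not valid[a])
    return mask


O_SHAPE = [[1, 1], [1, 1]]
shapes = {(4, 0): O_SHAPE}
actions = [(4, 0, 0), (4, 0, 1), EOS]
empty = [[0, 0, 0] for _ in range(3)]
left_blocked = [[9, 0, 0] for _ in range(3)]
mask_empty = mask_invalid_forward(empty, actions, shapes)
mask_blocked = mask_invalid_forward(left_blocked, actions, shapes)
```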
- states2proxy(states)[source]
Prepares a batch of states in “environment format” for a proxy: simply converts non-zero (non-empty) cells into 1s.
- Parameters:
states (list of 2D tensors or 3D tensor) – A batch of states in environment format, either as a list of states or as a single tensor.
- Returns:
A tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[height, width, batch]
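The conversion itself is a simple binarisation. The sketch below uses nested Python lists to stay dependency-free, whereas the method described above operates on tensors; the function name is hypothetical.

```python
def binarize_boards(states):
    """Maps every non-zero (occupied) cell to 1 and leaves empty cells at 0."""
    return [
        [[1 if cell != 0 else 0 for cell in row] for row in board]
        for board in states
    ]


# A batch with a single 2x2 board holding one cell of piece 100 and
# two cells of piece 400.
batch = [[[0, 100], [400, 400]]]
proxy_batch = binarize_boards(batch)
```

After this step the proxy no longer sees which piece occupies a cell, only whether the cell is occupied.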
- states2policy(states)[source]
Prepares a batch of states in “environment format” for the policy model.
See states2proxy().
- Parameters:
states (list of 2D tensors or 3D tensor) – A batch of states in environment format, either as a list of states or as a single tensor.
- Returns:
A tensor containing all the states in the batch.
- Return type:
torchtyping.TensorType[height, width, batch]
- state2readable(state=None)[source]
Converts a state (board) into a human-friendly string.
- Parameters:
state (Optional[torchtyping.TensorType[height, width]])
- readable2state(readable, alphabet={})[source]
Converts a human-readable string representing a state into a state as a list of positions.
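To illustrate the round trip between a board and a string, here is a minimal sketch. The string format (zero-padded cell numbers separated by spaces, one board row per line) is a hypothetical choice for this example and may differ from the format used by state2readable() and readable2state().

```python
def board_to_readable(board):
    """Renders a board as text, one row per line, cells padded to 3 digits."""
    return "\n".join(" ".join(f"{cell:03d}" for cell in row) for row in board)


def readable_to_board(readable):
    """Parses the text produced by board_to_readable back into a board."""
    return [[int(cell) for cell in line.split()]
            for line in readable.splitlines()]


board = [[0, 100], [400, 400]]
text = board_to_readable(board)
restored = readable_to_board(text)
```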
- get_parents(state=None, done=None, action=None)[source]
Determines all parents and actions that lead to state.
See: _is_parent_action()
- Parameters:
state (list) – Representation of a state, as a list of length length where each element is the position at each dimension.
done (bool) – Whether the trajectory is done. If None, done is taken from instance.
action (None) – Ignored
- Returns:
parents (list) – List of parents in state format
actions (list) – List of actions that lead to state for each parent in parents
- Return type:
Tuple[List, List]
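One way to reason about parents, sketched under simplifying assumptions: a piece can only have been the last one dropped if no cell of another piece lies above any of its cells, and removing such a piece yields a parent board. This is not the library's implementation (which relies on _is_parent_action()), and the sketch returns the removed piece numbers rather than full actions. Boards are nested lists with row 0 at the top.

```python
def candidate_parents(board):
    """Returns (parents, removed_ids) for pieces that could be the last drop."""
    height, width = len(board), len(board[0])
    piece_ids = sorted({cell for row in board for cell in row if cell != 0})
    parents, removed_ids = [], []
    for pid in piece_ids:
        blocked = False
        for r in range(height):
            for c in range(width):
                if board[r][c] != pid:
                    continue
                # A cell of another piece above this cell means this piece
                # could not have been dropped straight down last.
                if any(board[above][c] not in (0, pid) for above in range(r)):
                    blocked = True
        if not blocked:
            parents.append(
                [[0 if cell == pid else cell for cell in row] for row in board]
            )
            removed_ids.append(pid)
    return parents, removed_ids


# An O piece (400) resting on a horizontal I piece (100).
board = [
    [0, 400, 400, 0],
    [0, 400, 400, 0],
    [100, 100, 100, 100],
]
parents, removed_ids = candidate_parents(board)
```

In this example only the O piece can be removed, since the I piece lies underneath it.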
- step(action, skip_mask_check=False)[source]
Executes step given an action.
- Parameters:
action (tuple) – Action to be executed. An action is a tuple of int values (piece, rotation, col), as described in get_action_space().
skip_mask_check (bool) – If True, skip computing forward mask of invalid actions to check if the action is valid.
- Returns:
self.state (list) – The state after executing the action
action (tuple) – Action executed
valid (bool) – False, if the action is not allowed for the current state.
- Return type:
Tuple[List[int], Tuple[int], bool]
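The core of a step, dropping a piece into a column until it rests on the floor or on another piece, can be sketched as follows. The function, the piece representation as a 0/1 matrix and the return convention are assumptions made for illustration; the actual step() also handles EOS and the mask check.

```python
def drop_piece(board, piece, col, piece_id):
    """Drops `piece` with its left edge at `col`; returns (new_board, valid)."""
    height, width = len(board), len(board[0])
    p_h, p_w = len(piece), len(piece[0])

    def collides(row):
        # True if the piece overlaps an occupied cell when its top is at `row`.
        return any(
            piece[r][c] and board[row + r][col + c]
            for r in range(p_h)
            for c in range(p_w)
        )

    if col < 0 or col + p_w > width or p_h > height or collides(0):
        return board, False  # invalid action: the board is left unchanged
    row = 0
    # Move down one row at a time while the next position is free.
    while row + 1 + p_h <= height and not collides(row + 1):
        row += 1
    new_board = [list(r) for r in board]
    for r in range(p_h):
        for c in range(p_w):
            if piece[r][c]:
                new_board[row + r][col + c] = piece_id
    return new_board, True


O_SHAPE = [[1, 1], [1, 1]]
board0 = [[0, 0, 0] for _ in range(4)]
board1, valid1 = drop_piece(board0, O_SHAPE, col=1, piece_id=400)
board2, valid2 = drop_piece(board1, O_SHAPE, col=0, piece_id=401)
board3, valid3 = drop_piece(board2, O_SHAPE, col=0, piece_id=402)
```

The third drop is rejected because the first column is already occupied at the top of the board, which mirrors the valid flag returned by step().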
- set_state(state, done=False)[source]
Sets the state and done. If done is True but incompatible with state (done is True, allow_eos_before_full is False and state is not full), then done is forced to False and a warning is printed. Also makes sure that state is a tensor.
- Parameters:
state (torchtyping.TensorType[height, width])
done (Optional[bool])
- plot_samples_topk(samples, rewards, k_top=10, n_rows=2, dpi=150, **kwargs)[source]
Plot tetris boards of top K samples.
- Parameters:
samples (list) – List of terminating states sampled from the policy.
rewards (list) – Rewards of the samples.
k_top (int) – The number of samples that will be included in the plot. The k_top samples with the highest reward are selected.
n_rows (int) – Number of rows in the plot. The number of columns will be calculated according to n_rows and k_top.
dpi (int) – DPI (dots per inch) of the figure, to determine the resolution.
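The selection and layout logic described by these parameters can be sketched without any plotting. The function below is hypothetical, and computing the number of columns as the ceiling of k_top / n_rows is an assumption about how the layout is derived.

```python
import math


def select_topk(samples, rewards, k_top=10, n_rows=2):
    """Returns the k_top highest-reward samples and the column count."""
    k = min(k_top, len(samples))
    # Indices of the samples sorted by decreasing reward.
    order = sorted(range(len(samples)), key=lambda i: rewards[i],
                   reverse=True)[:k]
    n_cols = math.ceil(k / n_rows)
    return [samples[i] for i in order], n_cols


samples = ["a", "b", "c", "d", "e"]
rewards = [0.1, 0.9, 0.5, 0.7, 0.3]
top, n_cols = select_topk(samples, rewards, k_top=3, n_rows=2)
```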