gflownet.envs.tree
IMPORTANT: this environment is not up to date.
Classes
- NodeType: Encodes two types of nodes present in a tree.
- Operator: Operator based on which the decision is made (< or >=).
- Status: Status of the node.
- Stage: Current stage of the tree, encoded as part of the state.
- ActionType: Type of action that will be passed to Tree.step. Refer to Stage for details.
- Attribute: Contains indices of individual attributes in a state tensor.
Module Contents
- class gflownet.envs.tree.NodeType
Encodes two types of nodes present in a tree:
- 0 - condition node (any node other than a leaf), which stores the feature on which the decision is made and the threshold that is used.
- 1 - classifier node (leaf), which stores the class output that will be predicted once that node is reached.
- class gflownet.envs.tree.Operator
Operator based on which the decision is made (< or >=).
We assume the convention that the left child outputs label 0 and the right child outputs label 1 if the operator is <, and the opposite if the operator is >=. That way, during prediction, we can act as if the operator were always the same and only care about the output label.
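The convention above can be checked with a small standalone sketch. It assumes that a node predicts class 0 when its condition "x <op> threshold" holds, and that traversal always sends x < threshold to the left child; the helper names are hypothetical and not part of the library.

```python
import operator

# Hypothetical check of the label convention; not taken from gflownet.
OPS = {"<": operator.lt, ">=": operator.ge}


def child_labels(op: str) -> tuple:
    """(left label, right label) under the documented convention."""
    return (0, 1) if op == "<" else (1, 0)


def predict_direct(x: float, threshold: float, op: str) -> int:
    # Class 0 when the node's condition "x <op> threshold" holds, class 1 otherwise.
    return 0 if OPS[op](x, threshold) else 1


def predict_uniform(x: float, threshold: float, op: str) -> int:
    # Always branch with "<" (left if x < threshold) and read the child's label.
    left, right = child_labels(op)
    return left if x < threshold else right
```

Both functions agree for either operator, which is why prediction only needs the stored labels.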
- class gflownet.envs.tree.Status
Status of the node. Every node except the one on which a macro step was initiated will be marked as inactive; the node will be marked as active iff the process of its splitting is in progress.
- class gflownet.envs.tree.Stage
Current stage of the tree, encoded as part of the state:
- 0 - complete: no macro step has been initiated, and the only allowed action is to pick one of the leaves for splitting.
- 1 - leaf: a leaf was picked for splitting, and the only allowed action is to pick the feature on which it will be split.
- 2 - feature: a feature was picked, and the only allowed action is to pick a threshold for splitting.
- 3 - threshold: a threshold was picked, and the only allowed action is to pick an operator.
- 4 - operator: an operator was picked. The only allowed action from here is to finalize the splitting process and spawn two new leaves, which should be done automatically, after which the stage should be changed to complete.
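The stage cycle can be sketched as a small state machine. This is an illustrative copy: the integer values come from the list above, and the action type indices (0 pick leaf, 1 pick feature, 2 pick threshold, 3 pick operator) come from get_action_space. The member names and helper functions are assumptions, not the library's Stage class.

```python
from enum import IntEnum


class Stage(IntEnum):
    # Values follow the documentation; illustrative copy, not the library class.
    COMPLETE = 0
    LEAF = 1
    FEATURE = 2
    THRESHOLD = 3
    OPERATOR = 4


def allowed_action_type(stage: Stage):
    """Action type allowed in `stage` (0 leaf, 1 feature, 2 threshold, 3 operator).

    Returns None for OPERATOR, where the split is finalized automatically.
    """
    return None if stage == Stage.OPERATOR else int(stage)


def next_stage(stage: Stage) -> Stage:
    # OPERATOR wraps back to COMPLETE once the two new leaves are spawned.
    return Stage.COMPLETE if stage == Stage.OPERATOR else Stage(stage + 1)
```

In each of the first four stages, exactly one action type is valid, and its index equals the stage value.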
- class gflownet.envs.tree.ActionType
Type of action that will be passed to Tree.step. Refer to Stage for details.
- class gflownet.envs.tree.Attribute
Contains indices of individual attributes in a state tensor.
Types of attributes defining each node of the tree:
- 0 - node type (condition or classifier);
- 1 - index of the feature used for splitting (condition node only, -1 otherwise);
- 2 - decision threshold (condition node only, -1 otherwise);
- 3 - class output (classifier node only, -1 otherwise); in the case of the < operator, the left child will have class 0 and the right child class 1, and the opposite for the >= operator;
- 4 - whether the node is active (1 if the node was picked and the macro step has not finished yet, 0 otherwise).
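The following sketch shows how a single node's five-attribute vector could be built. Only the index positions come from the list above. The member names, the helper functions, and the use of plain Python lists instead of a tensor are assumptions made for illustration.

```python
from enum import IntEnum


class Attribute(IntEnum):
    # Positions as documented; member names are illustrative assumptions.
    TYPE = 0
    FEATURE = 1
    THRESHOLD = 2
    CLASS = 3
    ACTIVE = 4


CONDITION, CLASSIFIER = 0, 1  # NodeType values


def condition_node(feature: int, threshold: float) -> list:
    # Condition node: feature and threshold are set, the class output is unused (-1).
    return [CONDITION, feature, threshold, -1, 0]


def classifier_node(label: int, active: bool = False) -> list:
    # Classifier node (leaf): only the class output is set; a leaf is active
    # while it is being split.
    return [CLASSIFIER, -1, -1, label, int(active)]
```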
- class gflownet.envs.tree.Tree(X_train=None, y_train=None, X_test=None, y_test=None, data_path=None, scale_data=True, max_depth=10, continuous=True, n_thresholds=9, threshold_components=1, beta_params_min=1.0, beta_params_max=2.0, fixed_distr_params={'beta_alpha': 2.0, 'beta_beta': 5.0}, random_distr_params={'beta_alpha': 1.0, 'beta_beta': 1.0}, policy_format='mlp', test_args={'top_k_trees': 0}, **kwargs)
Bases: gflownet.envs.base.GFlowNetEnv
- Parameters:
X_train (Optional[numpy.typing.NDArray])
y_train (Optional[numpy.typing.NDArray])
X_test (Optional[numpy.typing.NDArray])
y_test (Optional[numpy.typing.NDArray])
data_path (Optional[str])
scale_data (bool)
max_depth (int)
continuous (bool)
n_thresholds (Optional[int])
threshold_components (int)
beta_params_min (float)
beta_params_max (float)
fixed_distr_params (dict)
random_distr_params (dict)
policy_format (str)
test_args (dict)
- X_train
Train dataset, with dimensionality (n_observations, n_features). It may be None if a data set is provided via data_path.
- Type:
np.array
- y_train
Train labels, with dimensionality (n_observations,). It may be None if a data set is provided via data_path.
- Type:
np.array
- X_test
Test dataset, with dimensionality (n_observations, n_features). It may be None if a data set is provided via data_path, or if you don’t want to perform test set evaluation.
- Type:
np.array
- y_test
Test labels, with dimensionality (n_observations,). It may be None if a data set is provided via data_path, or if you don’t want to perform test set evaluation.
- Type:
np.array
- data_path
A path to a data set, with the following options:
- *.pkl: pickled dict with X_train, y_train, and (optionally) X_test and y_test variables.
- *.csv: CSV file containing an optional ‘Split’ column in the last position, with ‘train’ and ‘test’ values, and M remaining columns, where the first (M - 1) columns are used to construct the input X and the M-th column is the target y.
Ignored if X_train and y_train are not None.
- Type:
str
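A CSV following the layout described for data_path could be split as shown below. This is a hypothetical loader, not the library's code. It assumes a header row, numeric feature values, integer labels, and that every row is training data when no ‘Split’ column is present.

```python
import csv
import io


def load_tree_csv(text: str):
    """Hypothetical loader returning (X_train, y_train, X_test, y_test) as lists."""
    rows = list(csv.reader(io.StringIO(text)))
    header, body = rows[0], rows[1:]
    has_split = header[-1] == "Split"
    X_train, y_train, X_test, y_test = [], [], [], []
    for row in body:
        split = row[-1] if has_split else "train"
        values = [float(v) for v in (row[:-1] if has_split else row)]
        x, y = values[:-1], int(values[-1])  # first M - 1 columns -> X, M-th -> y
        if split == "test":
            X_test.append(x)
            y_test.append(y)
        else:
            X_train.append(x)
            y_train.append(y)
    return X_train, y_train, X_test, y_test
```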
- scale_data
Whether to perform min-max scaling on the provided data (to a [0; 1] range).
- Type:
bool
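Min-max scaling maps each feature column to [0, 1] via (x - min) / (max - min). The sketch below fits the statistics on the training set and reuses them for the test set. That choice, and mapping constant columns to 0, are assumptions; the documentation does not say which data the scaling statistics come from.

```python
def min_max_scale(X_train, X_test=None):
    """Scale every feature column to [0, 1] using training minima and maxima.

    Assumption: test data reuses the training statistics, so its values may fall
    slightly outside [0, 1]. Constant columns are mapped to 0.
    """
    n_features = len(X_train[0])
    lows = [min(row[j] for row in X_train) for j in range(n_features)]
    highs = [max(row[j] for row in X_train) for j in range(n_features)]

    def scale(X):
        return [
            [
                (row[j] - lows[j]) / (highs[j] - lows[j]) if highs[j] > lows[j] else 0.0
                for j in range(n_features)
            ]
            for row in X
        ]

    return scale(X_train), (scale(X_test) if X_test is not None else None)
```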
- continuous
Whether the environment should operate in the continuous mode (in which distribution parameters are predicted for the threshold) or in the discrete mode (in which the threshold is chosen from a discrete set of possible values).
- Type:
bool
- n_thresholds
Number of uniformly distributed thresholds in a (0; 1) range that will be used in the discrete mode. Ignored if continuous is True.
- Type:
int
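One plausible reading of "uniformly distributed thresholds in a (0; 1) range" is an evenly spaced grid that excludes both endpoints. The sketch below follows that interpretation, which is an assumption. It is equivalent to numpy.linspace(0, 1, n_thresholds + 2)[1:-1], and with the default n_thresholds=9 it yields 0.1, 0.2, ..., 0.9.

```python
def discrete_thresholds(n_thresholds: int) -> list:
    # n_thresholds evenly spaced points strictly inside (0, 1): k / (n_thresholds + 1).
    return [k / (n_thresholds + 1) for k in range(1, n_thresholds + 1)]
```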
- policy_format
Type of policy that will be used with the environment, either ‘mlp’ or ‘gnn’. Influences which state2policy functions will be used.
- Type:
str
- threshold_components
The number of mixture components that will be used for sampling the threshold.
- Type:
int
- get_action_space()
Actions are tuples containing:
- action type: 0 - pick leaf to split, 1 - pick feature, 2 - pick threshold, 3 - pick operator;
- node index;
- action value, depending on the action type: pick leaf: current class output; pick feature: feature index; pick threshold: threshold value; pick operator: operator index.
- Return type:
List[Tuple[int, int, int]]
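The following hypothetical trace shows one macro step as a sequence of action tuples. Only the (action type, node index, action value) layout and the action type indices come from the documentation. The node index, feature, threshold, and the meaning of operator index 0 are made-up values used for illustration.

```python
# Hypothetical trace of one macro step splitting leaf 0; values are illustrative.
PICK_LEAF, PICK_FEATURE, PICK_THRESHOLD, PICK_OPERATOR = 0, 1, 2, 3

macro_step = [
    (PICK_LEAF, 0, 1),          # split leaf 0, whose current class output is 1
    (PICK_FEATURE, 0, 2),       # split on feature index 2
    (PICK_THRESHOLD, 0, 0.35),  # threshold value (continuous mode)
    (PICK_OPERATOR, 0, 0),      # operator index 0, assumed here to mean "<"
]


def is_well_ordered(actions) -> bool:
    # Within a macro step, action types must appear in the order 0, 1, 2, 3.
    return [a[0] for a in actions] == [PICK_LEAF, PICK_FEATURE, PICK_THRESHOLD, PICK_OPERATOR]
```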
- step(action, skip_mask_check=False)
Executes a step given an action.
- Parameters:
action (tuple) – Action to be executed. See: self.get_action_space()
skip_mask_check (bool) – If True, skip computing forward mask of invalid actions to check if the action is valid.
- Returns:
self.state (list) – The state after executing the action
action (tuple) – Action executed
valid (bool) – False, if the action is not allowed for the current state.
- Return type:
Tuple[List[int], Tuple[int, int, Union[int, float]], bool]
- step_backwards(action, skip_mask_check=False)
Executes a backward step given an action.
- Parameters:
action (tuple) – Action from the action space.
skip_mask_check (bool) – If True, skip computing the backward mask of invalid actions to check if the action is valid.
- Returns:
self.state (list) – The state after executing the action.
action (tuple) – The given action.
valid (bool) – False, if the action is not allowed for the current state.
- Return type:
Tuple[List[int], Tuple[int], bool]
- set_state(state, done=False)
Sets the state and the done flag. If done is True but incompatible with the state (i.e., the Stage is not COMPLETE), done is forced to False and a warning is printed.
- Parameters:
state (List)
done (Optional[bool])
- sample_actions_batch_continuous(policy_outputs, mask=None, states_from=None, is_backward=False, random_action_prob=0.0, temperature_logits=1.0)
Samples a batch of actions from a batch of policy outputs in the continuous mode.
- Parameters:
policy_outputs (torchtyping.TensorType[n_states, policy_output_dim])
mask (Optional[torchtyping.TensorType[n_states, policy_output_dim]])
states_from (Optional[List])
is_backward (Optional[bool])
random_action_prob (Optional[float])
temperature_logits (Optional[float])
- Return type:
Tuple[List[Tuple], torchtyping.TensorType[n_states]]
- sample_actions_batch(policy_outputs, mask=None, states_from=None, is_backward=False, random_action_prob=0.0, temperature_logits=1.0)
Samples a batch of actions from a batch of policy outputs.
- Parameters:
policy_outputs (torchtyping.TensorType[n_states, policy_output_dim])
mask (Optional[torchtyping.TensorType[n_states, policy_output_dim]])
states_from (Optional[List])
is_backward (Optional[bool])
random_action_prob (Optional[float])
temperature_logits (Optional[float])
- Return type:
Tuple[List[Tuple], torchtyping.TensorType[n_states]]
- get_logprobs_continuous(policy_outputs, actions, mask=None, states_from=None, is_backward=False)
Computes log probabilities of actions given policy outputs and actions.
- Parameters:
policy_outputs (torchtyping.TensorType[n_states, policy_output_dim])
actions (Union[List, torchtyping.TensorType[n_states, action_dim]])
mask (torchtyping.TensorType[n_states, 1])
states_from (Optional[List])
is_backward (bool)
- Return type:
torchtyping.TensorType[batch_size]
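For a PICK_THRESHOLD action, the continuous log-probability is the log-density of the threshold under the Beta mixture described in get_policy_output_continuous. The sketch below computes log sum_k w_k Beta(x; alpha_k, beta_k) with a log-sum-exp. It assumes the weights and Beta parameters have already been decoded from the policy output. This is the standard mixture density, not the library's implementation.

```python
import math


def beta_log_pdf(x: float, alpha: float, beta: float) -> float:
    # log of x^(a-1) (1-x)^(b-1) / B(a, b), with B(a, b) = Γ(a)Γ(b)/Γ(a+b).
    log_norm = math.lgamma(alpha) + math.lgamma(beta) - math.lgamma(alpha + beta)
    return (alpha - 1) * math.log(x) + (beta - 1) * math.log1p(-x) - log_norm


def mixture_log_prob(x: float, weights, alphas, betas) -> float:
    """log sum_k w_k Beta(x; alpha_k, beta_k), computed with log-sum-exp.

    `weights` must be non-negative and sum to 1 (e.g. a softmax of weight logits).
    """
    terms = [
        math.log(w) + beta_log_pdf(x, a, b)
        for w, a, b in zip(weights, alphas, betas)
        if w > 0
    ]
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))
```

As a sanity check, an equal mixture of Beta(2, 1) and Beta(1, 2) has density 2x/2 + 2(1 - x)/2 = 1 everywhere, so its log-density is 0.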
- get_logprobs(policy_outputs, actions, mask=None, states_from=None, is_backward=False)
Computes log probabilities of actions given policy outputs and actions.
- Parameters:
policy_outputs (torchtyping.TensorType[n_states, policy_output_dim])
actions (Union[List, torchtyping.TensorType[n_states, action_dim]])
mask (torchtyping.TensorType[n_states, 1])
states_from (Optional[List])
is_backward (bool)
- Return type:
torchtyping.TensorType[batch_size]
- states2policy_mlp(states)
Prepares a batch of states in torch “GFlowNet format” for an MLP policy model. It replaces the NaNs by -2s, removes the activity attribute, and explicitly appends the attribute vector of the active node (if present).
- Parameters:
states (Union[List[torchtyping.TensorType[state_dim]], torchtyping.TensorType[batch_size, state_dim]])
- Return type:
torchtyping.TensorType[batch_size, policy_input_dim]
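The three transformations described for states2policy_mlp can be sketched on a single tree as follows. The layout is an assumption: a list of five-attribute node rows, with NaN rows for missing nodes and the stage omitted. Padding with -2 when no node is active is also an assumption, chosen to keep the input size fixed.

```python
import math

ACTIVE = 4  # index of the activity attribute (see Attribute)


def states2policy_mlp_sketch(nodes):
    """Hypothetical flattening of one tree for an MLP policy input."""
    def clean(attrs):
        # Replace NaN by -2 and drop the activity attribute.
        return [-2.0 if math.isnan(v) else float(v) for i, v in enumerate(attrs) if i != ACTIVE]

    flat = [v for attrs in nodes for v in clean(attrs)]
    active = [attrs for attrs in nodes if attrs[ACTIVE] == 1]
    # Append the active node's attributes, or -2 padding if none is active (assumption).
    flat += clean(active[0]) if active else [-2.0] * (len(nodes[0]) - 1)
    return flat
```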
- readable2state(readable)
Converts a human-readable representation of a state into the standard format.
- static find_active(state)
Returns the index of the (only) active node. Assumes that an active node exists, i.e., that a macro step is in progress.
- Parameters:
state (torch.Tensor)
- Return type:
int
- static get_n_nodes(state)
Returns the number of nodes in a tree represented by the given state.
- Parameters:
state (torch.Tensor)
- Return type:
int
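find_active and get_n_nodes can be sketched as simple scans over the node rows. This assumes the same hypothetical layout as the other sketches: one five-attribute row per node slot, with rows for missing nodes filled with NaN. The function names mirror the methods but are standalone illustrations.

```python
import math

TYPE, ACTIVE = 0, 4  # attribute indices from Attribute


def find_active(nodes) -> int:
    # Index of the single node whose activity attribute is 1; assumes one exists.
    matches = [k for k, attrs in enumerate(nodes) if attrs[ACTIVE] == 1]
    if len(matches) != 1:
        raise ValueError("expected exactly one active node")
    return matches[0]


def get_n_nodes(nodes) -> int:
    # Count existing nodes, i.e. rows whose node type is not NaN (assumed layout).
    return sum(1 for attrs in nodes if not math.isnan(attrs[TYPE]))
```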
- get_policy_output_continuous(params)
Defines the structure of the output of the policy model, from which an action is to be determined or sampled. It initializes the output tensor using the parameters provided in the argument params.
The output of the policy of a Tree environment consists of a discrete and a continuous part. The discrete part (first part) corresponds to the discrete actions, while the continuous part (second part) corresponds to the single continuous action, that is, sampling the threshold of a condition node.
The latter is modelled by a mixture of Beta distributions. Therefore, the continuous part of the policy output is a vector of dimensionality c * 3, where c is the number of components in the mixture (self.components). The three parameters of each component are the following:
- the weight of the component in the mixture;
- the logit(alpha) parameter of the Beta distribution used to sample the threshold;
- the logit(beta) parameter of the Beta distribution used to sample the threshold.
Note: contrary to other environments where there is a need to model a mixture of discrete and continuous distributions (for example, to consider the possibility of sampling the EOS action instead of a continuous action), there is no such need here, because the continuous action is either the only valid action or not valid at all.
- Parameters:
params (dict)
- Return type:
torchtyping.TensorType[policy_output_dim]
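Sampling a threshold from the continuous part of the policy output could look like the sketch below. Two assumptions are not confirmed by the documentation. First, the 3 * c vector is laid out as [weight logits | logit(alpha) | logit(beta)]. Second, each logit is mapped into [beta_params_min, beta_params_max] with a sigmoid. The defaults 1.0 and 2.0 match the Tree signature.

```python
import math
import random


def sample_threshold(cont_output, n_components, beta_min=1.0, beta_max=2.0, rng=random):
    """Sample one threshold from a Beta mixture (assumed output layout)."""
    c = n_components
    w_logits = cont_output[:c]
    a_logits = cont_output[c:2 * c]
    b_logits = cont_output[2 * c:3 * c]

    def sigmoid(z):
        return 1.0 / (1.0 + math.exp(-z))

    # Map logits into [beta_min, beta_max] (assumption).
    alphas = [beta_min + (beta_max - beta_min) * sigmoid(z) for z in a_logits]
    betas = [beta_min + (beta_max - beta_min) * sigmoid(z) for z in b_logits]
    # Softmax weights (unnormalized is fine for random.choices), pick a component,
    # then draw from its Beta distribution.
    m = max(w_logits)
    weights = [math.exp(z - m) for z in w_logits]
    k = rng.choices(range(c), weights=weights)[0]
    return rng.betavariate(alphas[k], betas[k])
```

Usage: passing rng=random.Random(seed) makes draws reproducible.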
- get_policy_output(params)
Defines the structure of the output of the policy model, from which an action is to be determined or sampled.
- Parameters:
params (dict)
- Return type:
torchtyping.TensorType[policy_output_dim]
- get_mask_invalid_actions_forward(state=None, done=None)
Returns a list with the same length as the action space, with values:
- True if the forward action is invalid from the current state;
- False otherwise.
For continuous or hybrid environments, this mask corresponds to the discrete part of the action space.
- Parameters:
state (Optional[torch.Tensor])
done (Optional[bool])
- Return type:
List[bool]
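A simplified forward mask can be derived from the Stage rules. This sketch encodes only two of them: the action type must match the current stage, and in the complete stage only existing leaves can be picked. The real method likely checks more, for example max_depth. The node index -1 used for non-leaf actions in the test is a hypothetical placeholder.

```python
def forward_mask_sketch(action_space, stage, leaves):
    """Hypothetical forward mask: True marks an invalid action."""
    mask = []
    for action_type, node, _value in action_space:
        # Invalid unless the action type is the one allowed by the current stage.
        invalid = action_type != stage
        # Picking a node for splitting is only valid if that node is a leaf.
        if action_type == 0 and node not in leaves:
            invalid = True
        mask.append(invalid)
    return mask
```

In stage 4 (operator), no action type matches, so every action is masked. This is consistent with the split being finalized automatically.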
- get_mask_invalid_actions_backward_continuous(state=None, done=None, parents_a=None)
Simply appends to the standard “discrete part” of the mask a dummy part corresponding to the continuous part of the policy output so as to match the dimensionality.
- Parameters:
state (Optional[torch.Tensor])
done (Optional[bool])
parents_a (Optional[List])
- Return type:
List
- get_mask_invalid_actions_backward(state=None, done=None, parents_a=None)
Returns a list with the same length as the action space, with values:
- True if the backward action is invalid from the current state;
- False otherwise.
For continuous or hybrid environments, this mask corresponds to the discrete part of the action space.
The base implementation should be common to all discrete spaces, as it relies on get_parents, which is environment-specific and must be implemented. Continuous environments will probably need to implement their own version of this method.
- Parameters:
state (Optional[torch.Tensor])
done (Optional[bool])
parents_a (Optional[List])
- Return type:
List
- get_parents(state=None, done=None, action=None)
Determines all parents and actions that lead to state.
In continuous environments, get_parents() should return only the parent from which action leads to state.
- Parameters:
state (list) – Representation of a state
done (bool) – Whether the trajectory is done. If None, done is taken from instance.
action (tuple) – Last action performed
- Returns:
parents (list) – List of parents in state format
actions (list) – List of actions that lead to state for each parent in parents
- Return type:
Tuple[List, List]
- static action2representative_continuous(action)
Replaces the continuous value of a PICK_THRESHOLD action by -1 so that it can be contrasted with the action space and masks.
- Parameters:
action (Tuple)
- Return type:
Tuple
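The replacement described above amounts to a one-line rule on the action tuple. The standalone function below (hypothetically named to_representative) assumes three-element actions and uses action type 2 for pick threshold, as listed in get_action_space.

```python
PICK_THRESHOLD = 2  # action type index, per get_action_space


def to_representative(action: tuple) -> tuple:
    # Replace the sampled threshold by -1 so the action matches its action-space entry.
    if action[0] == PICK_THRESHOLD:
        return action[:2] + (-1,)
    return action
```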
- action2representative(action)
For continuous or hybrid environments, converts a continuous action into its representative in the action space. Discrete actions remain identical, thus fully discrete environments do not need to re-implement this method. Continuous environments should re-implement this method in order to replace continuous actions by their representatives in the action space.
- Parameters:
action (Tuple)
- Return type:
Tuple
- static state2pyg(state, n_features, one_hot=True, add_self_loop=False)
Converts the given state into a PyG (PyTorch Geometric) graph.
- Parameters:
state (torch.Tensor)
n_features (int)
one_hot (bool)
add_self_loop (bool)
- Return type:
torch_geometric.data.Data
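The graph connectivity could be derived as shown below, returning edges as two parallel lists (source, target). This matches the shape of the edge_index that torch_geometric.data.Data expects once converted to a 2 x E tensor. The heap-order layout (children of node k at 2k + 1 and 2k + 2), the parent-to-child edge direction, and the NaN rows for missing nodes are all assumptions.

```python
import math

TYPE = 0  # attribute index of the node type


def tree_edges(nodes, add_self_loop=False):
    """Parent -> child edges of a heap-ordered tree; NaN type marks missing nodes."""
    exists = [not math.isnan(attrs[TYPE]) for attrs in nodes]
    src, dst = [], []
    for k, present in enumerate(exists):
        if not present:
            continue
        for child in (2 * k + 1, 2 * k + 2):
            if child < len(nodes) and exists[child]:
                src.append(k)
                dst.append(child)
        if add_self_loop:
            src.append(k)
            dst.append(k)
    return [src, dst]
```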
- static predict(state, x, *, return_k=False, k=0)
Recursively predicts the output label given a feature vector x of a single observation.
If return_k is True, also returns the index of the node in which the prediction was made.
- Parameters:
state (torch.Tensor)
x (numpy.typing.NDArray)
return_k (bool)
k (int)
- Return type:
Union[int, Tuple[int, int]]
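The recursion can be sketched as follows. The sketch assumes a heap-ordered node layout (children of node k at 2k + 1 and 2k + 2), which the documentation does not specify. Thanks to the Operator convention, it always branches with "<" and only reads the class output at the leaf.

```python
TYPE, FEATURE, THRESHOLD, CLASS = 0, 1, 2, 3  # attribute indices
CONDITION = 0  # NodeType value of a condition node


def predict(nodes, x, *, return_k=False, k=0):
    """Recursive prediction over a heap-ordered tree (assumed layout)."""
    attrs = nodes[k]
    if attrs[TYPE] == CONDITION:
        go_left = x[int(attrs[FEATURE])] < attrs[THRESHOLD]
        child = 2 * k + 1 if go_left else 2 * k + 2
        return predict(nodes, x, return_k=return_k, k=child)
    label = int(attrs[CLASS])
    return (label, k) if return_k else label
```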
- static plot(state, path=None)
Plots the tree represented by the given state.
- Parameters:
path (Optional[Union[pathlib.Path, str]])
- Return type:
None
- test(samples)
Computes a dictionary of metrics, as described in Tree._compute_scores, for both the training data and, if available, the test data. If self.test_args[‘top_k_trees’] != 0, also plots the top k trees and saves them in the log directory.
- Parameters:
samples (Tensor) – Collection of sampled states representing the ensemble.
- Returns:
Dictionary of (metric_name, score) key-value pairs.
- Return type:
dict