gflownet.envs.tree

IMPORTANT: this environment is not up to date.

Classes

NodeType

Encodes the two types of nodes present in a tree: condition nodes and classifier nodes (leaves).

Operator

Operator based on which the decision is made (< or >=).

Status

Status of the node. Every node except the one on which a macro step was initiated is marked as inactive.

Stage

Current stage of the tree, encoded as part of the state.

ActionType

Type of action that will be passed to Tree.step. Refer to Stage for details.

Attribute

Contains indices of individual attributes in a state tensor.

Tree

Module Contents

class gflownet.envs.tree.NodeType[source]

Encodes two types of nodes present in a tree:

  • 0 - condition node (any node other than a leaf), which stores the feature on which the decision is made and the threshold used.

  • 1 - classifier node (leaf), which stores the class output that will be predicted once that node is reached.

CONDITION = 0[source]
CLASSIFIER = 1[source]
class gflownet.envs.tree.Operator[source]

Operator based on which the decision is made (< or >=).

We assume the convention of having left child output label = 0 and right child label = 1 if operator is <, and the opposite if operator is >=. That way, during prediction, we can act as if the operator was always the same (and only care about the output label).

LT = 0[source]
GTE = 1[source]
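The label convention above can be illustrated with a minimal sketch. The helper names child_labels and predict_at_split are hypothetical, and the routing rule (always go left when the feature value is below the threshold) is an assumption made for illustration rather than a statement about the library's implementation.

```python
from enum import IntEnum


class Operator(IntEnum):
    # Mirrors the values documented above.
    LT = 0
    GTE = 1


def child_labels(operator):
    """Return (left_label, right_label) under the documented convention."""
    if operator == Operator.LT:
        return 0, 1
    return 1, 0


def predict_at_split(x_value, threshold, operator):
    # Assumption: routing is always "go left if x < threshold", so the
    # operator only changes which label each child carries.
    left, right = child_labels(operator)
    return left if x_value < threshold else right


print(predict_at_split(0.2, 0.5, Operator.LT))   # 0
print(predict_at_split(0.2, 0.5, Operator.GTE))  # 1
```

Because the routing never changes, only the stored child labels need to be consulted at prediction time.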
class gflownet.envs.tree.Status[source]

Status of the node. Every node except the one on which a macro step was initiated will be marked as inactive; the node will be marked as active iff the process of its splitting is in progress.

INACTIVE = 0[source]
ACTIVE = 1[source]
class gflownet.envs.tree.Stage[source]

Current stage of the tree, encoded as part of the state:

  • 0 - complete: no macro step is initiated, and the only allowed action is to pick one of the leaves for splitting.

  • 1 - leaf: a leaf was picked for splitting, and the only allowed action is picking a feature on which it will be split.

  • 2 - feature: a feature was picked, and the only allowed action is picking a threshold for splitting.

  • 3 - threshold: a threshold was picked, and the only allowed action is picking an operator.

  • 4 - operator: an operator was picked. The only allowed action from here is finalizing the splitting process and spawning two new leaves, which should be done automatically, upon which the stage should be changed to complete.

COMPLETE = 0[source]
LEAF = 1[source]
FEATURE = 2[source]
THRESHOLD = 3[source]
OPERATOR = 4[source]
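The stage description amounts to a small state machine, which might be sketched as follows. The enum values mirror those documented for Stage and ActionType; the ALLOWED_ACTION table and the next_stage helper are hypothetical, and collapsing the operator stage straight back to complete is an assumption based on the note that finalization "should be done automatically".

```python
from enum import IntEnum


class Stage(IntEnum):
    COMPLETE = 0
    LEAF = 1
    FEATURE = 2
    THRESHOLD = 3
    OPERATOR = 4


class ActionType(IntEnum):
    PICK_LEAF = 0
    PICK_FEATURE = 1
    PICK_THRESHOLD = 2
    PICK_OPERATOR = 3


# The single action type allowed in each stage (hypothetical lookup table).
ALLOWED_ACTION = {
    Stage.COMPLETE: ActionType.PICK_LEAF,
    Stage.LEAF: ActionType.PICK_FEATURE,
    Stage.FEATURE: ActionType.PICK_THRESHOLD,
    Stage.THRESHOLD: ActionType.PICK_OPERATOR,
}


def next_stage(stage, action_type):
    """Return the stage reached after applying action_type in stage."""
    if ALLOWED_ACTION.get(stage) != action_type:
        raise ValueError(f"{action_type!r} is not allowed in {stage!r}")
    new_stage = Stage(stage + 1)
    # Assumption: once the operator is picked, the split is finalized
    # automatically and the tree returns to the COMPLETE stage.
    if new_stage == Stage.OPERATOR:
        return Stage.COMPLETE
    return new_stage
```

Applying the four action types in order from COMPLETE walks through one full macro step and ends in COMPLETE again.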
class gflownet.envs.tree.ActionType[source]

Type of action that will be passed to Tree.step. Refer to Stage for details.

PICK_LEAF = 0[source]
PICK_FEATURE = 1[source]
PICK_THRESHOLD = 2[source]
PICK_OPERATOR = 3[source]
class gflownet.envs.tree.Attribute[source]

Contains indices of individual attributes in a state tensor.

Types of attributes defining each node of the tree:

  • 0 - node type (condition or classifier),

  • 1 - index of the feature used for splitting (condition node only, -1 otherwise),

  • 2 - decision threshold (condition node only, -1 otherwise),

  • 3 - class output (classifier node only, -1 otherwise); in the case of the < operator the left child will have class = 0 and the right child class = 1, and the opposite for the >= operator,

  • 4 - whether the node has active status (1 if the node was picked and the macro step has not finished yet, 0 otherwise).

TYPE = 0[source]
FEATURE = 1[source]
THRESHOLD = 2[source]
CLASS = 3[source]
ACTIVE = 4[source]
N = 5[source]
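As an illustration of this layout, the sketch below builds the attribute vector of a single node. The indices mirror the documented Attribute values; the helpers condition_node and classifier_node are hypothetical, and plain Python lists stand in for the state tensor.

```python
from enum import IntEnum


class Attribute(IntEnum):
    TYPE = 0
    FEATURE = 1
    THRESHOLD = 2
    CLASS = 3
    ACTIVE = 4
    N = 5


def condition_node(feature, threshold, active=0):
    """Attribute vector of a condition node (hypothetical helper)."""
    node = [-1.0] * Attribute.N
    node[Attribute.TYPE] = 0  # NodeType.CONDITION
    node[Attribute.FEATURE] = feature
    node[Attribute.THRESHOLD] = threshold
    node[Attribute.ACTIVE] = active
    return node


def classifier_node(class_output, active=0):
    """Attribute vector of a classifier (leaf) node (hypothetical helper)."""
    node = [-1.0] * Attribute.N
    node[Attribute.TYPE] = 1  # NodeType.CLASSIFIER
    node[Attribute.CLASS] = class_output
    node[Attribute.ACTIVE] = active
    return node
```

Attributes that do not apply to a node type keep the placeholder value -1, as described above.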
class gflownet.envs.tree.Tree(X_train=None, y_train=None, X_test=None, y_test=None, data_path=None, scale_data=True, max_depth=10, continuous=True, n_thresholds=9, threshold_components=1, beta_params_min=1.0, beta_params_max=2.0, fixed_distr_params={'beta_alpha': 2.0, 'beta_beta': 5.0}, random_distr_params={'beta_alpha': 1.0, 'beta_beta': 1.0}, policy_format='mlp', test_args={'top_k_trees': 0}, **kwargs)[source]

Bases: gflownet.envs.base.GFlowNetEnv

Parameters:
  • X_train (Optional[numpy.typing.NDArray])

  • y_train (Optional[numpy.typing.NDArray])

  • X_test (Optional[numpy.typing.NDArray])

  • y_test (Optional[numpy.typing.NDArray])

  • data_path (Optional[str])

  • scale_data (bool)

  • max_depth (int)

  • continuous (bool)

  • n_thresholds (Optional[int])

  • threshold_components (int)

  • beta_params_min (float)

  • beta_params_max (float)

  • fixed_distr_params (dict)

  • random_distr_params (dict)

  • policy_format (str)

  • test_args (dict)

X_train

Train dataset, with dimensionality (n_observations, n_features). It may be None if a data set is provided via data_path.

Type:

np.array

y_train[source]

Train labels, with dimensionality (n_observations,). It may be None if a data set is provided via data_path.

Type:

np.array

X_test

Test dataset, with dimensionality (n_observations, n_features). It may be None if a data set is provided via data_path, or if you don’t want to perform test set evaluation.

Type:

np.array

y_test

Test labels, with dimensionality (n_observations,). It may be None if a data set is provided via data_path, or if you don’t want to perform test set evaluation.

Type:

np.array

data_path

A path to a data set, with the following options:

  • *.pkl: Pickled dict with X_train, y_train, and (optional) X_test and y_test variables.

  • *.csv: CSV containing an optional ‘Split’ column in the last place, containing ‘train’ and ‘test’ values, and M remaining columns, where the first (M - 1) columns will be taken to construct the input X, and M-th column will be the target y.

Ignored if X_train and y_train are not None.

Type:

str
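To make the CSV layout concrete, here is a minimal sketch of how such a file could be split into inputs and targets. It is not the library's loader: the function split_csv is hypothetical, and it assumes a header row, a split column named exactly "Split", integer class labels, and that rows go to the training set when no split column is present.

```python
import csv
import io


def split_csv(text):
    """Split CSV text into {'train': (X, y), 'test': (X, y)}."""
    rows = list(csv.reader(io.StringIO(text)))
    header, records = rows[0], rows[1:]
    # Assumption: the optional split column is the last one, named "Split".
    has_split = header[-1] == "Split"
    data = {"train": ([], []), "test": ([], [])}
    for row in records:
        split = row[-1] if has_split else "train"
        values = row[:-1] if has_split else row
        X, y = data[split]
        X.append([float(v) for v in values[:-1]])  # first M - 1 columns
        y.append(int(values[-1]))                  # M-th column is the target
    return data


sample = "\n".join([
    "f1,f2,label,Split",
    "0.1,1.0,0,train",
    "0.9,3.0,1,test",
])
print(split_csv(sample)["train"])
```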

scale_data

Whether to perform min-max scaling on the provided data (to a [0; 1] range).

Type:

bool
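Min-max scaling to the [0; 1] range can be sketched as below. The function min_max_scale is hypothetical; scaling each feature column independently and mapping constant columns to 0.0 are assumptions, since the documentation does not state how these cases are handled.

```python
def min_max_scale(X):
    """Scale each column of X (a list of rows) to the [0, 1] range."""
    columns = list(zip(*X))
    lows = [min(column) for column in columns]
    highs = [max(column) for column in columns]
    scaled = []
    for row in X:
        scaled.append([
            # Assumption: a constant column carries no information, map it to 0.
            (value - low) / (high - low) if high > low else 0.0
            for value, low, high in zip(row, lows, highs)
        ])
    return scaled


print(min_max_scale([[1.0, 10.0], [3.0, 20.0], [2.0, 10.0]]))
```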

max_depth[source]

Maximum depth of a tree.

Type:

int

continuous[source]

Whether the environment should operate in a continuous mode (in which distribution parameters are predicted for the threshold) or the discrete mode (in which there is a discrete set of possible thresholds to choose from).

Type:

bool

n_thresholds

Number of uniformly distributed thresholds in a (0; 1) range that will be used in the discrete mode. Ignored if continuous is True.

Type:

int
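One plausible reading of "uniformly distributed thresholds in a (0; 1) range" is a grid of evenly spaced interior points. The sketch below follows that reading; the function uniform_thresholds and the exact spacing are assumptions, not taken from the implementation.

```python
def uniform_thresholds(n_thresholds):
    """Evenly spaced thresholds strictly inside the (0, 1) interval."""
    # Assumption: endpoints 0 and 1 are excluded, spacing is 1 / (n + 1).
    return [(i + 1) / (n_thresholds + 1) for i in range(n_thresholds)]


print(uniform_thresholds(9))
```

Under this assumption the default n_thresholds=9 gives a grid with a spacing of 0.1.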

policy_format

Type of policy that will be used with the environment, either ‘mlp’ or ‘gnn’. Influences which state2policy functions will be used.

Type:

str

threshold_components

The number of mixture components that will be used for sampling the threshold.

Type:

int

y_train[source]
n_features[source]
max_depth = 10[source]
continuous = True[source]
test_args[source]
components = 1[source]
beta_params_min = 1.0[source]
beta_params_max = 2.0[source]
n_nodes = 1023[source]
source[source]
default_class[source]
eos[source]
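The listed values max_depth = 10 and n_nodes = 1023 are consistent with the node count of a full binary tree of that depth. A short check, assuming that relationship (the function name is hypothetical):

```python
def n_nodes_for_depth(max_depth):
    """Number of nodes in a full binary tree with max_depth levels."""
    return 2 ** max_depth - 1


print(n_nodes_for_depth(10))  # 1023
```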
get_action_space()[source]
Each action is a tuple containing:

  1. action type: 0 - pick leaf to split, 1 - pick feature, 2 - pick threshold, 3 - pick operator,

  2. node index,

  3. action value, depending on the action type:

    • pick leaf: current class output,

    • pick feature: feature index,

    • pick threshold: threshold value,

    • pick operator: operator index.

Return type:

List[Tuple[int, int, int]]
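The tuple structure can be illustrated by decoding a few example actions. The helper describe_action and the lookup table are hypothetical and only restate the three-part layout described above.

```python
# (action name, meaning of the action value) per action type.
ACTION_FIELDS = {
    0: ("pick leaf", "current class output"),
    1: ("pick feature", "feature index"),
    2: ("pick threshold", "threshold value"),
    3: ("pick operator", "operator index"),
}


def describe_action(action):
    """Render an (action type, node index, action value) tuple as text."""
    action_type, node_index, value = action
    name, value_meaning = ACTION_FIELDS[action_type]
    return f"{name} at node {node_index}: {value_meaning} = {value}"


print(describe_action((1, 0, 3)))
# pick feature at node 0: feature index = 3
```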

step(action, skip_mask_check=False)[source]

Executes step given an action.

Parameters:
  • action (tuple) – Action to be executed. See: self.get_action_space()

  • skip_mask_check (bool) – If True, skip computing forward mask of invalid actions to check if the action is valid.

Returns:

  • self.state (list) – The state after executing the action

  • action (tuple) – Action executed

  • valid (bool) – False, if the action is not allowed for the current state.

Return type:

Tuple[List[int], Tuple[int, int, Union[int, float]], bool]

step_backwards(action, skip_mask_check=False)[source]

Executes a backward step given an action.

Parameters:
  • action (tuple) – Action from the action space.

  • skip_mask_check (bool) – If True, skip computing the mask of invalid actions used to check whether the action is valid.

Returns:

  • self.state (list) – The state after executing the action.

  • action (tuple) – Given action.

  • valid (bool) – False, if the action is not allowed for the current state.

Return type:

Tuple[List[int], Tuple[int], bool]

set_state(state, done=False)[source]

Sets the state and done. If done is True but incompatible with the state (its Stage is not COMPLETE), done is forced to False and a warning is printed.

Parameters:
  • state (List)

  • done (Optional[bool])

sample_actions_batch_continuous(policy_outputs, mask=None, states_from=None, is_backward=False, random_action_prob=0.0, temperature_logits=1.0)[source]

Samples a batch of actions from a batch of policy outputs in the continuous mode.

Parameters:
  • policy_outputs (torchtyping.TensorType[n_states, policy_output_dim])

  • mask (Optional[torchtyping.TensorType[n_states, policy_output_dim]])

  • states_from (Optional[List])

  • is_backward (Optional[bool])

  • random_action_prob (Optional[float])

  • temperature_logits (Optional[float])

Return type:

Tuple[List[Tuple], torchtyping.TensorType[n_states]]

sample_actions_batch(policy_outputs, mask=None, states_from=None, is_backward=False, random_action_prob=0.0, temperature_logits=1.0)[source]

Samples a batch of actions from a batch of policy outputs.

Parameters:
  • policy_outputs (torchtyping.TensorType[n_states, policy_output_dim])

  • mask (Optional[torchtyping.TensorType[n_states, policy_output_dim]])

  • states_from (Optional[List])

  • is_backward (Optional[bool])

  • random_action_prob (Optional[float])

  • temperature_logits (Optional[float])

Return type:

Tuple[List[Tuple], torchtyping.TensorType[n_states]]

get_logprobs_continuous(policy_outputs, actions, mask=None, states_from=None, is_backward=False)[source]

Computes log probabilities of actions given policy outputs and actions.

Parameters:
  • policy_outputs (torchtyping.TensorType[n_states, policy_output_dim])

  • actions (Union[List, torchtyping.TensorType[n_states, action_dim]])

  • mask (torchtyping.TensorType[n_states, 1])

  • states_from (Optional[List])

  • is_backward (bool)

Return type:

torchtyping.TensorType[batch_size]

get_logprobs(policy_outputs, actions, mask=None, states_from=None, is_backward=False)[source]

Computes log probabilities of actions given policy outputs and actions.

Parameters:
  • policy_outputs (torchtyping.TensorType[n_states, policy_output_dim])

  • actions (Union[List, torchtyping.TensorType[n_states, action_dim]])

  • mask (torchtyping.TensorType[n_states, 1])

  • states_from (Optional[List])

  • is_backward (bool)

Return type:

torchtyping.TensorType[batch_size]

states2policy_mlp(states)[source]

Prepares a batch of states in torch “GFlowNet format” for an MLP policy model. It replaces the NaNs by -2s, removes the activity attribute, and explicitly appends the attribute vector of the active node (if present).

Parameters:

states (Union[List[torchtyping.TensorType[state_dim]], torchtyping.TensorType[batch_size, state_dim]])

Return type:

torchtyping.TensorType[batch_size, policy_input_dim]
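The three transformations named in the description (NaN replacement, dropping the activity attribute, appending the active node's attributes) can be sketched on plain Python lists. This is an illustration only: the function state_to_mlp_input is hypothetical, the state is assumed to be a list of per-node attribute rows with NaN-filled rows for absent nodes, any additional state entries such as the stage are ignored, and the -2 placeholder used when no node is active is an assumption.

```python
import math

N_ATTRIBUTES = 5  # Attribute.N
ACTIVE = 4        # Attribute.ACTIVE


def state_to_mlp_input(nodes):
    """Flatten per-node attribute rows into a single policy input vector."""
    # Assumption: when no node is active, the appended block is all -2.
    active_row = [-2.0] * (N_ATTRIBUTES - 1)
    flat = []
    for node in nodes:
        cleaned = [-2.0 if math.isnan(v) else v for v in node]
        if cleaned[ACTIVE] == 1:
            active_row = cleaned[:ACTIVE]
        flat.extend(cleaned[:ACTIVE])  # drop the activity attribute
    return flat + active_row


nan = float("nan")
example_nodes = [
    [0, 1, 0.5, -1, 0],  # condition node
    [1, -1, -1, 0, 1],   # active leaf
    [nan] * 5,           # absent node
]
print(state_to_mlp_input(example_nodes))
```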

state2readable(state=None)[source]

Converts a state into human-readable representation.

readable2state(readable)[source]

Converts a human-readable representation of a state into the standard format.

static find_active(state)[source]

Get index of the (only) active node. Assumes that active node exists (that we are in the middle of a macro step).

Parameters:

state (torch.Tensor)

Return type:

int

static get_n_nodes(state)[source]

Returns the number of nodes in a tree represented by the given state.

Parameters:

state (torch.Tensor)

Return type:

int

get_policy_output_continuous(params)[source]

Defines the structure of the output of the policy model, from which an action is to be determined or sampled. It initializes the output tensor by using the parameters provided in the argument params.

The output of the policy of a Tree environment consists of a discrete and a continuous part. The discrete part (first part) corresponds to the discrete actions, while the continuous part (second part) corresponds to the single continuous action, namely the sampling of the threshold of a node.

The latter is modelled by a mixture of Beta distributions. Therefore, the continuous part of the policy output is a vector of dimensionality c * 3, where c is the number of components in the mixture (self.components). The three parameters of each component are the following:

  1. the weight of the component in the mixture

  2. the logit(alpha) parameter of the Beta distribution to sample the threshold.

  3. the logit(beta) parameter of the Beta distribution to sample the threshold.

Note: contrary to other environments where there is a need to model a mixture of discrete and continuous distributions (for example to consider the possibility of sampling the EOS action instead of a continuous action), there is no such need here because either the continuous action is the only valid action or it is not valid.
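Sampling a threshold from such a mixture of Beta distributions might look like the sketch below, which uses only the standard library. Several details are assumptions rather than documented behaviour: the interleaved ordering of the c * 3 parameters, normalising the component weights with a softmax, and mapping logit(alpha) and logit(beta) into the range [beta_params_min, beta_params_max] with a sigmoid. The function sample_threshold is hypothetical.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sample_threshold(params, beta_min=1.0, beta_max=2.0, rng=random):
    """Sample a threshold in (0, 1) from a mixture of Beta distributions.

    params is a flat list [w_1, a_1, b_1, ..., w_c, a_c, b_c] of raw
    policy outputs (assumed ordering).
    """
    components = [params[i:i + 3] for i in range(0, len(params), 3)]
    # Assumption: raw weights are turned into probabilities by a softmax.
    exps = [math.exp(weight) for weight, _, _ in components]
    total = sum(exps)
    weights = [e / total for e in exps]
    _, logit_alpha, logit_beta = rng.choices(components, weights=weights)[0]
    # Assumption: logits are squashed into [beta_min, beta_max].
    alpha = beta_min + (beta_max - beta_min) * sigmoid(logit_alpha)
    beta = beta_min + (beta_max - beta_min) * sigmoid(logit_beta)
    return rng.betavariate(alpha, beta)


rng = random.Random(0)
print(sample_threshold([0.0, 0.5, -0.5, 1.0, 0.0, 0.0], rng=rng))
```

With two components, as in the usage line, the parameter vector has 2 * 3 = 6 entries.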

Parameters:

params (dict)

Return type:

torchtyping.TensorType[policy_output_dim]

get_policy_output(params)[source]

Defines the structure of the output of the policy model, from which an action is to be determined or sampled.

Parameters:

params (dict)

Return type:

torchtyping.TensorType[policy_output_dim]

get_mask_invalid_actions_forward(state=None, done=None)[source]
Returns a list with the same length as the action space, with values:
  • True if the forward action is invalid from the current state.

  • False otherwise.

For continuous or hybrid environments, this mask corresponds to the discrete part of the action space.

Return type:

List[bool]

get_mask_invalid_actions_backward_continuous(state=None, done=None, parents_a=None)[source]

Simply appends to the standard “discrete part” of the mask a dummy part corresponding to the continuous part of the policy output so as to match the dimensionality.

Parameters:
  • state (Optional[torch.Tensor])

  • done (Optional[bool])

  • parents_a (Optional[List])

Return type:

List

get_mask_invalid_actions_backward(state=None, done=None, parents_a=None)[source]
Returns a list with the same length as the action space, with values:
  • True if the backward action is invalid from the current state.

  • False otherwise.

For continuous or hybrid environments, this mask corresponds to the discrete part of the action space.

The base implementation below should be common to all discrete spaces as it relies on get_parents, which is environment-specific and must be implemented. Continuous environments will probably need to implement their own version of this method.

Parameters:
  • state (Optional[torch.Tensor])

  • done (Optional[bool])

  • parents_a (Optional[List])

Return type:

List

get_parents(state=None, done=None, action=None)[source]

Determines all parents and actions that lead to state.

In continuous environments, get_parents() should return only the parent from which action leads to state.

Parameters:
  • state (list) – Representation of a state

  • done (bool) – Whether the trajectory is done. If None, done is taken from instance.

  • action (tuple) – Last action performed

Returns:

  • parents (list) – List of parents in state format

  • actions (list) – List of actions that lead to state for each parent in parents

Return type:

Tuple[List, List]

static action2representative_continuous(action)[source]

Replaces the continuous value of a PICK_THRESHOLD action by -1 so that it can be compared against the action space and masks.

Parameters:

action (Tuple)

Return type:

Tuple
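The replacement described above is small enough to sketch directly. The function name to_representative is hypothetical; the action layout (action type, node index, action value) follows get_action_space.

```python
PICK_THRESHOLD = 2  # ActionType.PICK_THRESHOLD


def to_representative(action):
    """Replace the continuous value of a PICK_THRESHOLD action by -1."""
    action_type, node_index, _ = action
    if action_type == PICK_THRESHOLD:
        return (action_type, node_index, -1)
    return action


print(to_representative((2, 4, 0.37)))  # (2, 4, -1)
print(to_representative((1, 4, 3)))     # (1, 4, 3)
```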

action2representative(action)[source]

For continuous or hybrid environments, converts a continuous action into its representative in the action space. Discrete actions remain identical, thus fully discrete environments do not need to re-implement this method. Continuous environments should re-implement this method in order to replace continuous actions by their representatives in the action space.

Parameters:

action (Tuple)

Return type:

Tuple

get_pyg_input_dim()[source]
Return type:

int

static state2pyg(state, n_features, one_hot=True, add_self_loop=False)[source]

Convert given state into a PyG graph.

Parameters:
  • state (torch.Tensor)

  • n_features (int)

  • one_hot (bool)

  • add_self_loop (bool)

Return type:

torch_geometric.data.Data

static predict(state, x, *, return_k=False, k=0)[source]

Recursively predict output label given a feature vector x of a single observation.

If return_k is True, will also return the index of the node in which prediction was made.

Parameters:
  • state (torch.Tensor)

  • x (numpy.typing.NDArray)

  • return_k (bool)

  • k (int)

Return type:

Union[int, Tuple[int, int]]
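A recursive prediction of this kind can be sketched on a simplified tree representation. This is not the library's implementation: the tree is a plain list of per-node attribute rows instead of a state tensor, the children of node k are assumed to sit at indices 2k + 1 and 2k + 2, and an observation is assumed to go left when its feature value is below the threshold.

```python
TYPE, FEATURE, THRESHOLD, CLASS = 0, 1, 2, 3
CONDITION, CLASSIFIER = 0, 1


def predict(nodes, x, *, return_k=False, k=0):
    """Predict the label of observation x, starting from node k."""
    node = nodes[k]
    if node[TYPE] == CLASSIFIER:
        return (node[CLASS], k) if return_k else node[CLASS]
    # Assumption: heap-style child indices and "left if x < threshold".
    if x[node[FEATURE]] < node[THRESHOLD]:
        child = 2 * k + 1
    else:
        child = 2 * k + 2
    return predict(nodes, x, return_k=return_k, k=child)


tree = [
    [CONDITION, 0, 0.5, -1],    # root: split on feature 0 at 0.5
    [CLASSIFIER, -1, -1, 0],    # left leaf predicts class 0
    [CLASSIFIER, -1, -1, 1],    # right leaf predicts class 1
]
print(predict(tree, [0.2]))                 # 0
print(predict(tree, [0.8], return_k=True))  # (1, 2)
```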

static plot(state, path=None)[source]

Plot current state of the tree.

Parameters:

path (Optional[Union[pathlib.Path, str]])

Return type:

None

test(samples)[source]

Computes a dictionary of metrics, as described in Tree._compute_scores, for both training and, if available, test data. If self.test_args[‘top_k_trees’] != 0, also plots the top k trees and saves them in the log directory.

Parameters:

samples (Tensor) – Collection of sampled states representing the ensemble.

Returns:

Dictionary of (metric_name, score) key-value pairs.

Return type:

dict