gflownet.envs.tree
==================

.. py:module:: gflownet.envs.tree

.. autoapi-nested-parse::

   IMPORTANT: this environment is not up to date.



Classes
-------

.. autoapisummary::

   gflownet.envs.tree.NodeType
   gflownet.envs.tree.Operator
   gflownet.envs.tree.Status
   gflownet.envs.tree.Stage
   gflownet.envs.tree.ActionType
   gflownet.envs.tree.Attribute
   gflownet.envs.tree.Tree


Module Contents
---------------

.. py:class:: NodeType

   Encodes the two types of nodes present in a tree:

   - 0 - condition node (any node other than a leaf), which stores the feature
     on which the decision is made and the threshold that is used.
   - 1 - classifier node (leaf), which stores the class output that will be
     predicted once that node is reached.


   .. py:attribute:: CONDITION
      :value: 0


   .. py:attribute:: CLASSIFIER
      :value: 1



.. py:class:: Operator

   Operator based on which the decision is made (< or >=). We assume the
   convention that the left child outputs label 0 and the right child outputs
   label 1 if the operator is <, and the opposite if the operator is >=. That
   way, during prediction, we can act as if the operator were always the same
   (and only care about the output label).


   .. py:attribute:: LT
      :value: 0


   .. py:attribute:: GTE
      :value: 1



.. py:class:: Status

   Status of the node. Every node except the one on which a macro step was
   initiated is marked as inactive; a node is marked as active iff the process
   of splitting it is in progress.


   .. py:attribute:: INACTIVE
      :value: 0


   .. py:attribute:: ACTIVE
      :value: 1



.. py:class:: Stage

   Current stage of the tree, encoded as part of the state.

   - 0 - complete: no macro step is in progress, and the only allowed action
     is to pick one of the leaves for splitting.
   - 1 - leaf: a leaf was picked for splitting, and the only allowed action is
     to pick the feature on which it will be split.
   - 2 - feature: a feature was picked, and the only allowed action is to pick
     a threshold for splitting.
   - 3 - threshold: a threshold was picked, and the only allowed action is to
     pick an operator.
   - 4 - operator: an operator was picked. The only allowed action from here is
     to finalize the splitting process and spawn two new leaves; this should be
     done automatically, after which the stage should be changed back to
     complete.


   .. py:attribute:: COMPLETE
      :value: 0


   .. py:attribute:: LEAF
      :value: 1


   .. py:attribute:: FEATURE
      :value: 2


   .. py:attribute:: THRESHOLD
      :value: 3


   .. py:attribute:: OPERATOR
      :value: 4



.. py:class:: ActionType

   Type of action that will be passed to Tree.step. Refer to Stage for details.


   .. py:attribute:: PICK_LEAF
      :value: 0


   .. py:attribute:: PICK_FEATURE
      :value: 1


   .. py:attribute:: PICK_THRESHOLD
      :value: 2


   .. py:attribute:: PICK_OPERATOR
      :value: 3



.. py:class:: Attribute

   Contains the indices of individual attributes in a state tensor. Each node
   of the tree is defined by the following attributes:

   - 0 - node type (condition or classifier);
   - 1 - index of the feature used for splitting (condition node only, -1
     otherwise);
   - 2 - decision threshold (condition node only, -1 otherwise);
   - 3 - class output (classifier node only, -1 otherwise); for the <
     operator the left child has class 0 and the right child has class 1, and
     the opposite holds for the >= operator;
   - 4 - whether the node is active (1 if the node was picked and the macro
     step has not finished yet, 0 otherwise).


   .. py:attribute:: TYPE
      :value: 0


   .. py:attribute:: FEATURE
      :value: 1


   .. py:attribute:: THRESHOLD
      :value: 2


   .. py:attribute:: CLASS
      :value: 3


   .. py:attribute:: ACTIVE
      :value: 4


   .. py:attribute:: N
      :value: 5



.. py:class:: Tree(X_train = None, y_train = None, X_test = None, y_test = None, data_path = None, scale_data = True, max_depth = 10, continuous = True, n_thresholds = 9, threshold_components = 1, beta_params_min = 1.0, beta_params_max = 2.0, fixed_distr_params = {'beta_alpha': 2.0, 'beta_beta': 5.0}, random_distr_params = {'beta_alpha': 1.0, 'beta_beta': 1.0}, policy_format = 'mlp', test_args = {'top_k_trees': 0}, **kwargs)

   Bases: :py:obj:`gflownet.envs.base.GFlowNetEnv`

   .. attribute:: X_train

      Train dataset, with dimensionality (n_observations, n_features). It may
      be None if a data set is provided via data_path.

      :type: np.ndarray

   .. attribute:: y_train

      Train labels, with dimensionality (n_observations,). It may be None if a
      data set is provided via data_path.

      :type: np.ndarray

   .. attribute:: X_test

      Test dataset, with dimensionality (n_observations, n_features). It may be
      None if a data set is provided via data_path, or if you don't want to
      perform test set evaluation.

      :type: np.ndarray

   .. attribute:: y_test

      Test labels, with dimensionality (n_observations,). It may be None if a
      data set is provided via data_path, or if you don't want to perform test
      set evaluation.

      :type: np.ndarray

   .. attribute:: data_path

      A path to a data set, with the following options:

      - ``*.pkl``: pickled dict with X_train, y_train, and (optionally) X_test
        and y_test variables.
      - ``*.csv``: CSV file with an optional 'Split' column in the last place,
        containing 'train' and 'test' values, and M remaining columns, where
        the first (M - 1) columns are used to construct the input X and the
        M-th column is the target y.

      Ignored if X_train and y_train are not None.

      :type: str

   .. attribute:: scale_data

      Whether to perform min-max scaling of the provided data to the [0; 1]
      range.

      :type: bool

   .. attribute:: max_depth

      Maximum depth of a tree.

      :type: int

   .. attribute:: continuous

      Whether the environment should operate in the continuous mode (in which
      distribution parameters are predicted for the threshold) or in the
      discrete mode (in which there is a discrete set of possible thresholds to
      choose from).

      :type: bool

   .. attribute:: n_thresholds

      Number of uniformly distributed thresholds in the (0; 1) range that will
      be used in the discrete mode. Ignored if continuous is True.

      :type: int

   .. attribute:: policy_format

      Type of policy that will be used with the environment, either 'mlp' or
      'gnn'. Determines which states2policy functions are used.

      :type: str

   .. attribute:: threshold_components

      The number of mixture components that will be used for sampling the
      threshold.

      :type: int


   .. py:attribute:: y_train


   .. py:attribute:: n_features


   .. py:attribute:: max_depth
      :value: 10


   .. py:attribute:: continuous
      :value: True


   .. py:attribute:: test_args


   .. py:attribute:: components
      :value: 1


   .. py:attribute:: beta_params_min
      :value: 1.0


   .. py:attribute:: beta_params_max
      :value: 2.0


   .. py:attribute:: n_nodes
      :value: 1023


   .. py:attribute:: source


   .. py:attribute:: default_class


   .. py:attribute:: eos


   .. py:method:: get_action_space()

      Actions are tuples containing:

      1. the action type:

         - 0 - pick a leaf to split,
         - 1 - pick a feature,
         - 2 - pick a threshold,
         - 3 - pick an operator;

      2. the node index;
      3. the action value, which depends on the action type:

         - pick leaf: current class output,
         - pick feature: feature index,
         - pick threshold: threshold value,
         - pick operator: operator index.


   .. py:method:: step(action, skip_mask_check = False)

      Executes a step given an action.

      :param action: Action to be executed. See: self.get_action_space()
      :type action: tuple
      :param skip_mask_check: If True, skip computing the forward mask of invalid
         actions to check whether the action is valid.
      :type skip_mask_check: bool
      :returns: * **self.state** (*list*) -- The state after executing the action.
                * **action** (*tuple*) -- Action executed.
                * **valid** (*bool*) -- False if the action is not allowed for the current state.


   .. py:method:: step_backwards(action, skip_mask_check = False)

      Executes a backward step given an action.

      :param action: Action from the action space.
      :type action: tuple
      :param skip_mask_check: If True, skip computing the backward mask of invalid
         actions to check whether the action is valid.
      :type skip_mask_check: bool
      :returns: * **self.state** (*list*) -- The state after executing the action.
                * **action** (*tuple*) -- Given action.
                * **valid** (*bool*) -- False if the action is not allowed for the current state.


   .. py:method:: set_state(state, done = False)

      Sets the state and done. If done is True but incompatible with the state
      (the Stage is not COMPLETE), done is forced to False and a warning is
      printed.


   .. py:method:: sample_actions_batch_continuous(policy_outputs, mask = None, states_from = None, is_backward = False, random_action_prob = 0.0, temperature_logits = 1.0)

      Samples a batch of actions from a batch of policy outputs in the
      continuous mode.


   .. py:method:: sample_actions_batch(policy_outputs, mask = None, states_from = None, is_backward = False, random_action_prob = 0.0, temperature_logits = 1.0)

      Samples a batch of actions from a batch of policy outputs.


   .. py:method:: get_logprobs_continuous(policy_outputs, actions, mask = None, states_from = None, is_backward = False)

      Computes the log probabilities of actions given policy outputs and
      actions.


   .. py:method:: get_logprobs(policy_outputs, actions, mask = None, states_from = None, is_backward = False)

      Computes the log probabilities of actions given policy outputs and
      actions.


   .. py:method:: states2policy_mlp(states)

      Prepares a batch of states in torch "GFlowNet format" for an MLP policy
      model.
      It replaces NaNs with -2, removes the activity attribute, and explicitly
      appends the attribute vector of the active node (if present).


   .. py:method:: state2readable(state=None)

      Converts a state into a human-readable representation.


   .. py:method:: readable2state(readable)

      Converts a human-readable representation of a state into the standard
      format.


   .. py:method:: find_active(state)
      :staticmethod:

      Returns the index of the (only) active node. Assumes that an active node
      exists, i.e. that a macro step is in progress.


   .. py:method:: get_n_nodes(state)
      :staticmethod:

      Returns the number of nodes in the tree represented by the given state.


   .. py:method:: get_policy_output_continuous(params)

      Defines the structure of the output of the policy model, from which an
      action is to be determined or sampled. It initializes the output tensor
      using the parameters provided in the argument params.

      The output of the policy of a Tree environment consists of a discrete
      and a continuous part. The discrete part (first part) corresponds to the
      discrete actions, while the continuous part (second part) corresponds to
      the single continuous action, namely sampling the threshold of a
      condition node. The latter is modelled by a mixture of Beta
      distributions. Therefore, the continuous part of the policy output is a
      vector of dimensionality c * 3, where c is the number of components in
      the mixture (self.components). The three parameters of each component
      are the following:

      1. the weight of the component in the mixture;
      2. the logit(alpha) parameter of the Beta distribution used to sample the
         threshold;
      3. the logit(beta) parameter of the Beta distribution used to sample the
         threshold.

      Note: unlike other environments, where a mixture of discrete and
      continuous distributions must be modelled (for example, to allow sampling
      the EOS action instead of a continuous action), there is no such need
      here, because either the continuous action is the only valid action or it
      is not valid at all.


   .. py:method:: get_policy_output(params)

      Defines the structure of the output of the policy model, from which an
      action is to be determined or sampled.


   .. py:method:: get_mask_invalid_actions_forward(state = None, done = None)

      Returns a list with the same length as the action space, with values:

      - True if the forward action is invalid from the current state;
      - False otherwise.

      For continuous or hybrid environments, this mask corresponds to the
      discrete part of the action space.


   .. py:method:: get_mask_invalid_actions_backward_continuous(state = None, done = None, parents_a = None)

      Simply appends to the standard "discrete part" of the mask a dummy part
      corresponding to the continuous part of the policy output, so as to
      match the dimensionality.


   .. py:method:: get_mask_invalid_actions_backward(state = None, done = None, parents_a = None)

      Returns a list with the same length as the action space, with values:

      - True if the backward action is invalid from the current state;
      - False otherwise.

      For continuous or hybrid environments, this mask corresponds to the
      discrete part of the action space. The base implementation should be
      common to all discrete spaces, as it relies on get_parents, which is
      environment-specific and must be implemented. Continuous environments
      will probably need to implement their own version of this method.


   .. py:method:: get_parents(state = None, done = None, action = None)

      Determines all parents and the actions that lead to state.

      In continuous environments, get_parents() should return only the parent
      from which action leads to state.

      :param state: Representation of a state
      :type state: list
      :param done: Whether the trajectory is done. If None, done is taken from
         the instance.
      :type done: bool
      :param action: Last action performed
      :type action: tuple
      :returns: * **parents** (*list*) -- List of parents in state format
                * **actions** (*list*) -- List of actions that lead to state for each parent in parents


   .. py:method:: action2representative_continuous(action)
      :staticmethod:

      Replaces the continuous value of a PICK_THRESHOLD action with -1 so that
      it can be compared against the action space and masks.


   .. py:method:: action2representative(action)

      For continuous or hybrid environments, converts a continuous action into
      its representative in the action space. Discrete actions remain
      identical, so fully discrete environments do not need to re-implement
      this method. Continuous environments should re-implement this method in
      order to replace continuous actions with their representatives in the
      action space.


   .. py:method:: get_pyg_input_dim()


   .. py:method:: state2pyg(state, n_features, one_hot = True, add_self_loop = False)
      :staticmethod:

      Converts the given state into a PyG graph.


   .. py:method:: predict(state, x, *, return_k = False, k = 0)
      :staticmethod:

      Recursively predicts the output label given a feature vector x of a
      single observation. If return_k is True, also returns the index of the
      node in which the prediction was made.


   .. py:method:: plot(state, path = None)
      :staticmethod:

      Plots the current state of the tree.


   .. py:method:: test(samples)

      Computes a dictionary of metrics, as described in Tree._compute_scores,
      for the training data and, if available, the test data. If
      self.test_args['top_k_trees'] != 0, also plots the top k trees
      (k = self.test_args['top_k_trees']) and saves them in the log directory.

      :param samples: Collection of sampled states representing the ensemble.
      :type samples: Tensor
      :returns: *Dictionary of (metric_name, score) key-value pairs.*
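

The ``Stage`` and ``ActionType`` classes above describe how a single split (a
macro step) is broken into a fixed sequence of actions. The following minimal
sketch mirrors that cycle in plain Python. The two enums copy the documented
values; ``ALLOWED_ACTION`` and ``next_stage`` are hypothetical helpers, not
part of ``gflownet.envs.tree``, and the sketch folds the automatic
finalization of the OPERATOR stage into the same call.

.. code-block:: python

   from enum import IntEnum


   class Stage(IntEnum):
       # Values copied from the Stage class documented above.
       COMPLETE = 0
       LEAF = 1
       FEATURE = 2
       THRESHOLD = 3
       OPERATOR = 4


   class ActionType(IntEnum):
       # Values copied from the ActionType class documented above.
       PICK_LEAF = 0
       PICK_FEATURE = 1
       PICK_THRESHOLD = 2
       PICK_OPERATOR = 3


   # Hypothetical lookup table: the single action type allowed in each stage.
   # OPERATOR has no entry because the split is finalized automatically.
   ALLOWED_ACTION = {
       Stage.COMPLETE: ActionType.PICK_LEAF,
       Stage.LEAF: ActionType.PICK_FEATURE,
       Stage.FEATURE: ActionType.PICK_THRESHOLD,
       Stage.THRESHOLD: ActionType.PICK_OPERATOR,
   }


   def next_stage(stage: Stage, action_type: ActionType) -> Stage:
       """Return the stage reached after applying action_type in stage."""
       if ALLOWED_ACTION.get(stage) != action_type:
           raise ValueError(f"{action_type.name} is not allowed in stage {stage.name}")
       reached = Stage(stage + 1)
       # Once the operator is picked, the two new leaves are spawned
       # automatically and the tree goes back to the COMPLETE stage.
       return Stage.COMPLETE if reached == Stage.OPERATOR else reached


   # One full macro step, applying the action types in order.
   stage = Stage.COMPLETE
   visited = [stage]
   for action_type in ActionType:
       stage = next_stage(stage, action_type)
       visited.append(stage)
   print([s.name for s in visited])  # ['COMPLETE', 'LEAF', 'FEATURE', 'THRESHOLD', 'COMPLETE']

Because each stage admits exactly one action type, the stage stored in the
state is enough to determine which kind of action comes next.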
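

The ``Operator`` convention and ``Tree.predict`` can be illustrated on a toy
tree. This sketch assumes a heap-style layout in which the children of node
``k`` are stored at indices ``2k + 1`` and ``2k + 2`` (consistent with
``n_nodes = 1023 = 2 ** 10 - 1`` for ``max_depth = 10``, but an assumption
here), stores each node as a plain list ordered by ``Attribute``, and uses the
hypothetical helpers ``left``, ``right``, ``split_leaf`` and ``predict``. It
does not reproduce the environment's tensor state or its NaN padding.

.. code-block:: python

   # Attribute indices, copied from the Attribute class documented above.
   TYPE, FEATURE, THRESHOLD, CLASS, ACTIVE = range(5)
   CONDITION, CLASSIFIER = 0, 1  # NodeType values
   LT, GTE = 0, 1  # Operator values


   def left(k: int) -> int:
       # Assumed heap layout: the left child of node k is stored at 2k + 1.
       return 2 * k + 1


   def right(k: int) -> int:
       # Assumed heap layout: the right child of node k is stored at 2k + 2.
       return 2 * k + 2


   def split_leaf(nodes, k, feature, threshold, operator):
       """Turn leaf k into a condition node with two classifier children."""
       nodes[k] = [CONDITION, feature, threshold, -1, 0]
       # Operator convention: < gives left = 0, right = 1; >= gives the opposite.
       left_class, right_class = (0, 1) if operator == LT else (1, 0)
       nodes[left(k)] = [CLASSIFIER, -1, -1, left_class, 0]
       nodes[right(k)] = [CLASSIFIER, -1, -1, right_class, 0]


   def predict(nodes, x, k=0):
       """Return (label, leaf index); traversal always applies the < test."""
       node = nodes[k]
       if node[TYPE] == CLASSIFIER:
           return node[CLASS], k
       child = left(k) if x[node[FEATURE]] < node[THRESHOLD] else right(k)
       return predict(nodes, x, child)


   max_depth = 3
   nodes = [None] * (2 ** max_depth - 1)  # 7 slots; 1023 when max_depth = 10
   nodes[0] = [CLASSIFIER, -1, -1, 0, 0]  # a single leaf to start from
   split_leaf(nodes, 0, feature=0, threshold=0.5, operator=LT)
   split_leaf(nodes, 2, feature=1, threshold=0.3, operator=GTE)

   print(predict(nodes, [0.2, 0.9]))  # (0, 1)
   print(predict(nodes, [0.7, 0.1]))  # (1, 5)
   print(predict(nodes, [0.7, 0.8]))  # (0, 6)

Traversal never inspects the operator: node 2 was split with ``>=``, so its
children simply carry the swapped labels 1 and 0, as the ``Operator``
docstring describes.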
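

``get_policy_output_continuous`` describes the continuous part of the policy
output as ``c * 3`` values, one weight, one logit(alpha) and one logit(beta)
per Beta mixture component. The NumPy sketch below samples one threshold from
such a vector. The ordering (all weights, then all logit(alpha), then all
logit(beta)), the softmax over the weights, and the squashing of the logits
into ``[beta_params_min, beta_params_max]`` are assumptions for illustration;
the environment may lay out and transform these parameters differently.
``sample_threshold`` is a hypothetical helper.

.. code-block:: python

   import numpy as np

   # Defaults of Tree(beta_params_min=1.0, beta_params_max=2.0).
   BETA_PARAMS_MIN = 1.0
   BETA_PARAMS_MAX = 2.0


   def sigmoid(z: np.ndarray) -> np.ndarray:
       return 1.0 / (1.0 + np.exp(-z))


   def sample_threshold(continuous_part, rng: np.random.Generator) -> float:
       """Sample one threshold from the continuous part of a policy output.

       Assumed layout (hypothetical): c weight logits, then c logit(alpha)
       values, then c logit(beta) values, for a total length of c * 3.
       """
       params = np.asarray(continuous_part, dtype=float)
       weight_logits, logit_alpha, logit_beta = np.split(params, 3)
       # Softmax turns the weight logits into mixture probabilities.
       weights = np.exp(weight_logits - weight_logits.max())
       weights /= weights.sum()
       # Assumed squashing of the logits into [BETA_PARAMS_MIN, BETA_PARAMS_MAX].
       span = BETA_PARAMS_MAX - BETA_PARAMS_MIN
       alphas = BETA_PARAMS_MIN + span * sigmoid(logit_alpha)
       betas = BETA_PARAMS_MIN + span * sigmoid(logit_beta)
       # Pick a component, then draw the threshold from its Beta distribution.
       component = rng.choice(len(weights), p=weights)
       return float(rng.beta(alphas[component], betas[component]))


   rng = np.random.default_rng(0)
   # threshold_components = 2, so the continuous part holds 2 * 3 = 6 values.
   continuous_part = [0.0, 1.0, -0.5, 0.5, 2.0, -1.0]
   thresholds = [sample_threshold(continuous_part, rng) for _ in range(1000)]
   print(all(0.0 <= t <= 1.0 for t in thresholds))  # True

Since every component is a Beta distribution, each sampled threshold falls in
the [0; 1] range, which matches the min-max scaled features when
``scale_data`` is True.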
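

``action2representative_continuous`` maps a sampled PICK_THRESHOLD action onto
the placeholder value -1 so that it lines up with the action space and masks.
A plain-Python sketch, assuming actions are the 3-tuples described in
``get_action_space()`` (the environment may use tensors or other
containers); ``to_representative`` is a hypothetical name, not the library
method:

.. code-block:: python

   from typing import Tuple

   # ActionType values, copied from the ActionType class documented above.
   PICK_LEAF, PICK_FEATURE, PICK_THRESHOLD, PICK_OPERATOR = range(4)

   # (action type, node index, action value), as described in get_action_space().
   Action = Tuple[int, int, float]


   def to_representative(action: Action) -> Action:
       """Replace the continuous value of a PICK_THRESHOLD action with -1."""
       action_type, node_index, _value = action
       if action_type == PICK_THRESHOLD:
           return (action_type, node_index, -1)
       # All other action types are already discrete and stay unchanged.
       return action


   print(to_representative((PICK_THRESHOLD, 4, 0.137)))  # (2, 4, -1)
   print(to_representative((PICK_FEATURE, 4, 3)))  # (1, 4, 3)

Every threshold sampled for the same node collapses onto one representative,
so a single mask entry can cover the whole continuous range.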