gflownet.utils.batch
====================

.. py:module:: gflownet.utils.batch


Classes
-------

.. autoapisummary::

   gflownet.utils.batch.Batch


Functions
---------

.. autoapisummary::

   gflownet.utils.batch.compute_logprobs_trajectories


Module Contents
---------------

.. py:class:: Batch(env = None, proxy = None, device = 'cpu', float_type = 32, collect_forwards_masks=False, collect_backwards_masks=False)

   :param env: An instance of the environment that will be used to form the batch.
   :type env: GFlowNetEnv
   :param proxy: An instance of a GFlowNet proxy that will be used to compute proxy values and rewards.
   :type proxy: Proxy
   :param device: torch.device or string indicating the device to use ("cpu" or "cuda")
   :type device: str or torch.device
   :param float_type: One of float torch.dtype or an int indicating the float precision (16, 32 or 64).
   :type float_type: torch.dtype or int

   .. py:attribute:: device
      :value: 'cpu'

   .. py:attribute:: float
      :value: 32

   .. py:attribute:: proxy
      :value: None

   .. py:attribute:: size
      :value: 0

   .. py:attribute:: envs

   .. py:attribute:: trajectories

   .. py:attribute:: is_backward

   .. py:attribute:: traj_indices
      :value: []

   .. py:attribute:: state_indices
      :value: []

   .. py:attribute:: states
      :value: []

   .. py:attribute:: actions
      :value: []

   .. py:attribute:: logprobs_forward
      :value: []

   .. py:attribute:: logprobs_backward
      :value: []

   .. py:attribute:: logprobs_forward_avail
      :value: []

   .. py:attribute:: logprobs_backward_avail
      :value: []

   .. py:attribute:: done
      :value: []

   .. py:attribute:: masks_invalid_actions_forward
      :value: []

   .. py:attribute:: masks_invalid_actions_backward
      :value: []

   .. py:attribute:: parents
      :value: []

   .. py:attribute:: parents_all
      :value: []

   .. py:attribute:: parents_actions_all
      :value: []

   .. py:attribute:: n_actions
      :value: []

   .. py:attribute:: states_policy
      :value: None

   .. py:attribute:: parents_policy
      :value: None

   .. py:attribute:: collect_forwards_masks
      :value: False

   .. py:attribute:: collect_backwards_masks
      :value: False

   .. 
py:method:: __len__()

   .. py:method:: batch_idx_to_traj_state_idx(batch_idx)

   .. py:method:: traj_idx_to_batch_indices(traj_idx)

   .. py:method:: traj_state_idx_to_batch_idx(traj_idx, state_idx)

   .. py:method:: traj_idx_action_idx_to_batch_idx(traj_idx, action_idx, backward)

   .. py:method:: idx2state_idx(idx)

   .. py:method:: rewards_available(log = False)

      Returns True if the (log)rewards are available.

      :param log: If True, check self._logrewards_available. Otherwise (default), check self._rewards_available.
      :type log: bool

      :returns: *bool* -- True if the (log)rewards are available, False otherwise.

   .. py:method:: rewards_parents_available(log = False)

      Returns True if the (log)rewards of the parents are available.

      :param log: If True, check self._logrewards_parents_available. Otherwise (default), check self._rewards_parents_available.
      :type log: bool

      :returns: *bool* -- True if the (log)rewards of the parents are available, False otherwise.

   .. py:method:: rewards_source_available(log = False)

      Returns True if the (log)rewards of the source are available.

      :param log: If True, check self._logrewards_source_available. Otherwise (default), check self._rewards_source_available.
      :type log: bool

      :returns: *bool* -- True if the (log)rewards of the source are available, False otherwise.

   .. py:method:: set_env(env)

      Sets the generic environment passed as an argument and initializes the environment-dependent properties.

   .. py:method:: set_proxy(proxy)

      Sets the proxy, used to compute rewards from a batch of states.

   .. py:method:: add_to_batch(envs, actions, logprobs, logprobs_rev, valids, backward = False, train = True)

      Adds information from a list of environments and actions to the batch after performing steps in the envs. If train is False, only the variables of terminating states are stored.

      :param envs: A list of environments (GFlowNetEnv).
      :type envs: list
      :param actions: A list of actions attempted or performed on the envs.
      :type actions: list
      :param logprobs: Log probabilities corresponding to the actions or None.
      :type logprobs: torch.tensor or list of None
      :param logprobs_rev: Log probabilities corresponding to the transitions in the opposite direction to the sampling direction, from the current states but with the actions added in the previous step. It may be a list of None.
      :type logprobs_rev: torch.tensor or list of None
      :param valids: A list of boolean values indicating whether the actions were valid.
      :type valids: list
      :param backward: A boolean value indicating whether the action was sampled backward (False by default). If True, the behavior is slightly different so as to match what is stored in forward sampling:

         - If it is the first state in the trajectory (action from a done state/env), then done is stored as True, instead of taking env.done, which will be False after having performed the step.
         - If it is not the first state in the trajectory, the stored state will be the previous one in the trajectory, to match the state-action stored in forward sampling and the convention that the source state is not stored, but the terminating state is repeated with action eos.
      :type backward: bool
      :param train: A boolean value indicating whether the data to add to the batch will be used for training. Optional, default is True.
      :type train: bool

   .. py:method:: get_n_trajectories()

      Returns the number of trajectories in the batch.

      :returns: *The number of trajectories in the batch (int).*

   .. py:method:: get_unique_trajectory_indices()

      Returns the unique trajectory indices as a list. These are the keys of self.trajectories, which is an OrderedDict.

   .. py:method:: get_trajectory_indices(consecutive = False, return_mapping_dict = False)

      Returns the trajectory index of all elements in the batch as a long int torch tensor.

      :param consecutive: If True, the trajectory indices are mapped to consecutive indices starting from 0, in the order of the OrderedDict self.trajectories.keys().
         If False (default), the trajectory indices are returned as they are.
      :type consecutive: bool
      :param return_mapping_dict: If True, the dictionary mapping actual_index: consecutive_index is returned as a second return value. Ignored if consecutive is False.
      :type return_mapping_dict: bool

      :returns: * **traj_indices** (*torch.tensor*) -- self.traj_indices as a long int torch tensor.
                * **traj_index_to_consecutive_dict** (*dict*) -- A dictionary mapping the actual trajectory indices in the Batch to the consecutive indices. Omitted if return_mapping_dict is False (default).

   .. py:method:: get_state_indices()

      Returns the state index of all elements in the batch as a long int torch tensor.

      :returns: **state_indices** (*torch.tensor*) -- self.state_indices as a long int torch tensor.

   .. py:method:: get_states(policy = False, proxy = False, force_recompute = False, indices = None)

      Returns all the states in the batch. The states are returned in "policy format" if policy is True, in "proxy format" if proxy is True, and otherwise in "GFlowNet" format (default). An error is raised if both policy and proxy are True.

      :param policy: If True, the policy format of the states is returned and self.states_policy is updated if not available yet or if force_recompute is True.
      :type policy: bool
      :param proxy: If True, the proxy format of the states is returned. States in proxy format are not stored.
      :type proxy: bool
      :param force_recompute: If True, the policy states are recomputed even if they are available. Ignored if policy is False.
      :type force_recompute: bool
      :param indices: 1-dimensional sequence of batch indices for selecting states, optional. If None (default), all the states will be returned.
      :type indices: list, tuple, tensor or np.ndarray

      :returns: **self.states or self.states_policy or self.states2proxy(self.states)** (*list or torch.tensor or ndarray*) -- The set of states in the selected format with the selected indices.
         If indices is None, all states of the batch are returned.

   .. py:method:: states2policy(states = None, traj_indices = None)

      Converts states from a list of states in GFlowNet format to a tensor of states in policy format.

      :param states: List of states in GFlowNet format.
      :type states: list
      :param traj_indices: Ids indicating which env corresponds to each state in states. It is only used if the environments are conditional, in order to call state2policy from the right environment. Ignored if self.conditional is False.
      :type traj_indices: list or torch.tensor

      :returns: **states** (*torch.tensor*) -- States in policy format.

   .. py:method:: states2proxy(states = None, traj_indices = None)

      Converts states from a list of states in GFlowNet format to a tensor of states in proxy format. Note that the implementation of this method differs from Batch.states2policy() because the latter always returns torch.tensors. The output of the present method can also be numpy arrays or Python lists, depending on the proxy.

      :param states: List of states in GFlowNet format.
      :type states: list
      :param traj_indices: Ids indicating which env corresponds to each state in states. It is only used if the environments are conditional, in order to call state2proxy from the right environment. Ignored if self.conditional is False.
      :type traj_indices: list or torch.tensor

      :returns: **states** (*torch.tensor or ndarray or list*) -- States in proxy format.

   .. py:method:: get_actions(indices = None)

      Returns the actions in the batch.

      :param indices: 1-dimensional sequence of batch indices for selecting actions, optional. If None (default), all the actions will be returned.
      :type indices: list, tuple, tensor or np.ndarray

      :returns: *list* -- The list of actions in the batch with selected indices. If indices is None, all the actions will be returned.

   .. py:method:: get_logprobs(backward = False)

      Returns the logprobs in the batch as a float tensor.
      If there are any None values in self.logprobs, the list cannot be converted to a tensor. This exception is caught and None is returned, as a signal that the logprobs are not available.

      :param backward: Whether the requested logprobs are of backward transitions.
      :type backward: bool

      :returns: * **logprobs** (*TensorType["n_states"]*) -- Tensor of logprobs.
                * **valids** (*TensorType["n_states"]*) -- Boolean tensor with flags indicating whether the corresponding logprobs are valid. The flag is False if the corresponding logprob value is a zero / None placeholder.

   .. py:method:: get_done()

      Returns the list of done flags as a boolean tensor.

   .. py:method:: get_parents(policy = False, force_recompute = False, indices = None)

      Returns the parent (single parent for each state) of all states in the batch. If the parents are not readily available, they are computed, together with all necessary components. Missing components and newly computed components are added to the batch (self.component is set). The parents are returned in "policy format" if policy is True, otherwise they are returned in "GFlowNet" format (default).

      :param policy: If True, the policy format of parents is returned. Otherwise, the GFlowNet format is returned.
      :type policy: bool
      :param force_recompute: If True, the parents are recomputed even if they are available.
      :type force_recompute: bool
      :param indices: 1-dimensional sequence of batch indices for selecting parents, optional. If None (default), the parents of all states in the batch will be returned.
      :type indices: list, tuple, tensor or np.ndarray

      :returns: **self.parents or self.parents_policy** (*torch.tensor*) -- The parents of the states selected by indices. If indices is None, the parents of all states in the batch are returned.

   .. py:method:: get_parents_indices(indices=None)

      Returns the indices of the parents of the states in the batch.
      The i-th item of the returned list is the index in self.states of the parent of self.states[i], if it is present there. If a parent is not present in self.states (because it is the source), the index is -1.

      :param indices: 1-dimensional sequence of batch indices for selecting parents indices, optional. If None (default), the indices of parents of all states in the batch will be returned.
      :type indices: list, tuple, tensor or np.ndarray

      :returns: *self.parents_indices* -- The indices in self.states of the parents of self.states.

   .. py:method:: get_parents_all(policy = False, force_recompute = False)

      Returns the whole set of parents, their corresponding actions and indices of all states in the batch. If the parents are not available (self._parents_all_available is False) or if force_recompute is True, then self._compute_parents_all() is called to compute the required components. The parents are returned in "policy format" if policy is True, otherwise they are returned in "GFlowNet" format (default).

      :param policy: If True, the policy format of parents is returned. Otherwise, the GFlowNet format is returned.
      :type policy: bool
      :param force_recompute: If True, the parents are recomputed even if they are available.
      :type force_recompute: bool

      :returns: * **self.parents_all or self.parents_all_policy** (*list or torch.tensor*) -- The whole set of parents of all states in the batch.
                * **self.parents_actions_all** (*torch.tensor*) -- The actions corresponding to each parent in self.parents_all or self.parents_all_policy, linking them to the corresponding state in the trajectory.
                * **self.parents_all_indices** (*torch.tensor*) -- The state index corresponding to each parent in self.parents_all or self.parents_all_policy, linking them to the corresponding state in the batch.

   .. 
py:method:: get_masks_forward(of_parents = False, force_recompute = False, indices = None)

      Returns the forward mask of invalid actions of all states in the batch, or of their parents in the trajectory if of_parents is True. The masks are computed via self._compute_masks_forward if they are not available or if force_recompute is True.

      :param of_parents: If True, the returned masks will correspond to the parents of the states, instead of to the states (default).
      :type of_parents: bool
      :param force_recompute: If True, the masks are recomputed even if they are available.
      :type force_recompute: bool
      :param indices: 1-dimensional sequence of batch indices for selecting masks, optional. If None (default), the masks of all states in the batch will be returned.
      :type indices: list, tuple, tensor or np.ndarray

      :returns: **self.masks_invalid_actions_forward** (*torch.tensor*) -- The forward mask of the selected states (by indices). If indices is None, the masks of all states in the batch will be returned.

   .. py:method:: get_masks_backward(force_recompute = False, indices = None)

      Returns the backward mask of invalid actions of all states in the batch. The masks are computed via self._compute_masks_backward if they are not available or if force_recompute is True.

      :param force_recompute: If True, the masks are recomputed even if they are available.
      :type force_recompute: bool
      :param indices: 1-dimensional sequence of batch indices for selecting masks, optional. If None (default), the masks of all states in the batch will be returned.
      :type indices: list, tuple, tensor or np.ndarray

      :returns: **self.masks_invalid_actions_backward** (*torch.tensor*) -- The backward mask of the selected states (by indices). If indices is None, the masks of all states in the batch will be returned.

   .. py:method:: get_rewards(log = False, force_recompute = False, do_non_terminating = False)

      Returns the rewards of all states in the batch (including not done).
      :param log: If True, return the logarithm of the rewards.
      :type log: bool
      :param force_recompute: If True, the rewards are recomputed even if they are available.
      :type force_recompute: bool
      :param do_non_terminating: If True, return the actual rewards of the non-terminating states. If False, non-terminating states will be assigned reward 0.
      :type do_non_terminating: bool

   .. py:method:: get_proxy_values(force_recompute = False, do_non_terminating = False)

      Returns the proxy values of all states in the batch (including not done).

      :param force_recompute: If True, the proxy values are recomputed even if they are available.
      :type force_recompute: bool
      :param do_non_terminating: If True, return the actual proxy values of the non-terminating states. If False, non-terminating states will be assigned value inf.
      :type do_non_terminating: bool

   .. py:method:: get_rewards_parents(log = False)

      Returns the rewards of all parents in the batch.

      :param log: If True, return the logarithm of the rewards.
      :type log: bool

      :returns: *self.rewards_parents or self.logrewards_parents* -- A tensor containing the rewards of the parents of self.states.

   .. py:method:: get_rewards_source(log = False)

      Returns rewards of the corresponding source states for each state in the batch.

      :param log: If True, return the logarithm of the rewards.
      :type log: bool

      :returns: *self.rewards_source or self.logrewards_source* -- A tensor containing the rewards of the source states.

   .. py:method:: get_terminating_states(sort_by = 'insertion', policy = False, proxy = False)

      Returns the terminating states in the batch, that is all states with done = True. The states will be returned in either GFlowNet format (default), policy (policy = True) or proxy (proxy = True) format. If both policy and proxy are True, it raises an error due to the ambiguity. The returned states may be sorted by order of insertion (sort_by = "insert[ion]", default) or by trajectory index (sort_by = "traj[ectory]").
      :param sort_by: Indicates how to sort the output:

         - insert[ion]: sort by order of insertion (states of trajectories that reached the terminating state first come first)
         - traj[ectory]: sort by trajectory index (the order in the ordered dict self.trajectories)
      :type sort_by: str
      :param policy: If True, the policy format of the states is returned.
      :type policy: bool
      :param proxy: If True, the proxy format of the states is returned.
      :type proxy: bool

   .. py:method:: get_terminating_rewards(sort_by = 'insertion', log = False, force_recompute = False)

      Returns the reward of the terminating states in the batch, that is all states with done = True. The returned rewards may be sorted by order of insertion (sort_by = "insert[ion]", default) or by trajectory index (sort_by = "traj[ectory]").

      :param sort_by: Indicates how to sort the output:

         - insert[ion]: sort by order of insertion (rewards of trajectories that reached the terminating state first come first)
         - traj[ectory]: sort by trajectory index (the order in the ordered dict self.trajectories)
      :type sort_by: str
      :param log: If True, return the logarithm of the rewards.
      :type log: bool
      :param force_recompute: If True, the rewards are recomputed even if they are available.
      :type force_recompute: bool

   .. py:method:: get_terminating_proxy_values(sort_by = 'insertion', force_recompute = False)

      Returns the proxy values of the terminating states in the batch, that is all states with done = True. The returned proxy values may be sorted by order of insertion (sort_by = "insert[ion]", default) or by trajectory index (sort_by = "traj[ectory]").

      :param sort_by: Indicates how to sort the output:

         - insert[ion]: sort by order of insertion (proxy values of trajectories that reached the terminating state first come first)
         - traj[ectory]: sort by trajectory index (the order in the ordered dict self.trajectories)
      :type sort_by: str
      :param force_recompute: If True, the proxy values are recomputed even if they are available.
      :type force_recompute: bool

   .. py:method:: get_actions_trajectories()

      Returns the actions corresponding to all trajectories in the batch, sorted by trajectory index (the order in the ordered dict self.trajectories).

   .. py:method:: get_states_of_trajectory(traj_idx, states = None, traj_indices = None)

      Returns the states of the trajectory indicated by traj_idx. If states and traj_indices are not None, then these will be the only states and trajectory indices considered.

      See: states2policy()
      See: states2proxy()

      :param traj_idx: Index of the trajectory from which to return the states.
      :type traj_idx: int
      :param states: States from the trajectory to consider.
      :type states: tensor, array or list
      :param traj_indices: Trajectory indices of the trajectory to consider.
      :type traj_indices: tensor, array or list

      :returns: *Tensor, array or list of states of the requested trajectory.*

   .. py:method:: get_logprobs_of_trajectory(traj_idx, backward = False)

      Returns the logprobs of the trajectory indicated by traj_idx.

      :param traj_idx: Index of the trajectory from which to return the logprobs.
      :type traj_idx: int
      :param backward: If True, returns backward logprobs, otherwise returns forward logprobs.
      :type backward: bool

      :returns: *A list of logprobs of the actions of the requested trajectory.*

   .. py:method:: merge(batches)

      Merges the current Batch (self) with the Batch or list of Batches passed as argument.

      :returns: *self*

   .. py:method:: is_valid()

      Performs basic checks on the current state of the batch.

      :returns: *True if all the checks pass, False otherwise.*

   .. py:method:: traj_indices_are_consecutive()

      Returns True if the trajectory indices start from 0 and are consecutive; False otherwise.

   .. py:method:: make_indices_consecutive()

      Updates the trajectory indices as well as the env ids such that they start from 0 and are consecutive. Note that only the trajectory indices are changed, but importantly the order of the main data in the batch is preserved.
      Examples:

      - Original indices: 0, 10, 20
      - New indices: 0, 1, 2

      - Original indices: 1, 5, 3
      - New indices: 0, 1, 2

      Note: this method is unused as of September 1st 2023, but is left here for potential future use.

   .. py:method:: get_item(item, env = None, traj_idx = None, action_idx = None, backward = False)

      Returns the item specified by item of either:

      - environment env, OR
      - trajectory traj_idx AND action number action_idx (in the order of sampling)

      If all arguments are given, then they must be consistent, otherwise an exception (assert) is raised due to ambiguity. If a mask is requested but is missing, it is computed and stored.

      :param item: String identifier of the item to retrieve from the batch. Options:

         - state
         - parent
         - action
         - done
         - mask_f[orward]
         - mask_b[ackward]
      :type item: str
      :param traj_idx: Trajectory index
      :type traj_idx: int
      :param action_idx: Action index. Regardless of forward or backward sampling, the n-th item sampled when forming the batch.
      :type action_idx: int
      :param backward: Whether the trajectory is sampled backward. False (forward) by default.
      :type backward: bool

      :returns: *The requested item if it is available, or None if it is not. An error is raised if the request can be identified as incorrect.*

   .. py:method:: get_indices_of_previous_transitions(envs, backward)

      Get batch indices of the latest elements (states, actions, etc.) added to the batch for the provided environments.

      :param envs: List of envs for which the batch indices are requested.
      :type envs: list
      :param backward: Whether the trajectories are sampled backward (True) or forward (False).
      :type backward: bool

      :returns: *list* -- Batch indices of the previous transitions for the requested environments. Environments without any previous transitions in the batch are assigned None.

   .. py:method:: get_actions_of_previous_transitions(envs, backward)

      Retrieves the latest actions added to the batch for the provided envs.

      :param envs: List of envs for which the actions are requested.
      :type envs: list
      :param backward: Indicates whether the trajectories are sampled backward (True) or forward (False).
      :type backward: bool

      :returns: **actions** (*list*) -- Actions of previous transitions for the requested envs.

.. py:function:: compute_logprobs_trajectories(batch, env = None, forward_policy = None, backward_policy = None, backward = False)

   Computes the forward or backward log probabilities of the trajectories in a batch. The log probabilities are computed according to the forward or backward policy models passed as parameters.

   :param batch: A batch of data, containing all the states in the trajectories.
   :type batch: Batch
   :param env: An instance of the environment used to compute log probabilities of state transitions. If None, batch.readonly_env is used.
   :type env: :py:class:`gflownet.envs.base.GFlowNetEnv`, optional
   :param forward_policy: The model used to compute the forward policy outputs from input states. It is ignored if `backward` is True.
   :type forward_policy: :py:class:`gflownet.policy.base.Policy`, optional
   :param backward_policy: Same as `forward_policy`, but for the backward policy. It is ignored if `backward` is False.
   :type backward_policy: :py:class:`gflownet.policy.base.Policy`, optional
   :param backward: False: log probabilities of forward trajectories. True: log probabilities of backward trajectories.
   :type backward: bool

   :returns: **logprobs** (*torch.tensor*) -- The log probabilities of the trajectories.
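
The idea of reducing per-transition log probabilities to one value per trajectory can be illustrated with a minimal, standalone sketch. It does not import or call gflownet, and it is not the library's implementation: the helper names ``consecutive_mapping`` and ``sum_logprobs_per_trajectory`` are hypothetical, and plain Python lists stand in for tensors. It assumes only that the log probability of a trajectory is the sum of the log probabilities of its transitions, and that trajectory indices are first mapped to consecutive indices starting from 0, as described for ``get_trajectory_indices(consecutive=True)``.

```python
from collections import OrderedDict


def consecutive_mapping(traj_indices):
    """Map arbitrary trajectory indices to 0..n-1 in order of first appearance.

    Hypothetical helper: mirrors the documented examples
    (0, 10, 20 -> 0, 1, 2 and 1, 5, 3 -> 0, 1, 2).
    """
    mapping = OrderedDict()
    for idx in traj_indices:
        if idx not in mapping:
            mapping[idx] = len(mapping)
    return mapping


def sum_logprobs_per_trajectory(traj_indices, logprobs):
    """Sum per-transition log probabilities into one value per trajectory.

    traj_indices[i] is the trajectory that transition i belongs to and
    logprobs[i] is the log probability of that transition.
    """
    mapping = consecutive_mapping(traj_indices)
    totals = [0.0] * len(mapping)
    for idx, logprob in zip(traj_indices, logprobs):
        totals[mapping[idx]] += logprob
    return totals, dict(mapping)


# Six transitions belonging to three trajectories with indices 0, 10 and 20.
traj_indices = [0, 10, 20, 0, 10, 0]
logprobs = [-0.5, -1.0, -2.0, -0.25, -0.5, -0.25]
totals, mapping = sum_logprobs_per_trajectory(traj_indices, logprobs)
print(totals)   # one summed log probability per trajectory
print(mapping)  # actual trajectory index -> consecutive index
```

With tensors, the same reduction is typically written as a scatter-add over the consecutive trajectory indices; the loop above is kept deliberately simple so that the bookkeeping is visible.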