gflownet.envs.torus
===================

.. py:module:: gflownet.envs.torus

.. autoapi-nested-parse::

   IMPORTANT: this environment is currently broken!

   Classes to represent hyper-torus environments


Classes
-------

.. autoapisummary::

   gflownet.envs.torus.Torus


Module Contents
---------------

.. py:class:: Torus(n_dim = 2, n_angles = 3, length_traj = 1, max_increment = 1, max_dim_per_action = 1, **kwargs)

   Bases: :py:obj:`gflownet.envs.base.GFlowNetEnv`

   Hyper-torus environment in which the action space consists of:

   - Increasing the angle index of dimension d
   - Decreasing the angle index of dimension d
   - Keeping all dimensions as they are

   The trajectory has a fixed length, length_traj.

   The state space is the concatenation of the angle index at each dimension and the
   number of actions.

   .. attribute:: n_dim

      Dimensionality of the torus

      :type: int

   .. attribute:: n_angles

      Number of angles into which each dimension is divided

      :type: int

   .. attribute:: length_traj

      Fixed length of the trajectory.

      :type: int

   .. py:attribute:: n_dim
      :value: 2

   .. py:attribute:: n_angles
      :value: 3

   .. py:attribute:: length_traj
      :value: 1

   .. py:attribute:: max_increment
      :value: 1

   .. py:attribute:: max_dim_per_action
      :value: 1

   .. py:attribute:: source_angles

   .. py:attribute:: source

   .. py:attribute:: eos

   .. py:attribute:: angle_rad

   .. py:method:: get_action_space()

      Constructs a list with all possible actions, including eos.

      An action is represented by a vector of length n_dim where each index d
      indicates the increment/decrement to apply to dimension d of the hyper-torus.
      A negative value indicates a decrement. The action "keep" (no
      increment/decrement of any dimension) is valid and is indicated by all zeros.

   .. py:method:: get_mask_invalid_actions_forward(state = None, done = None)

      Returns a list with the same length as the action space, with values:

      - True if the forward action is invalid from the current state.
      - False otherwise.

      All actions except EOS are valid if the maximum number of actions has not been
      reached, and vice versa.

   .. py:method:: states2proxy(states)

      Prepares a batch of states in "environment format" for the proxy: each state is
      a vector of length n_dim where each value is an angle in radians. The n_actions
      item is removed.

      :param states: A batch of states in environment format, either as a list of
                     states or as a single tensor.
      :type states: list or tensor

      :returns: *A tensor containing all the states in the batch.*

   .. py:method:: states2policy(states)

      Prepares a batch of states in "environment format" for the policy model: the
      policy format is a one-hot encoding of the states.

      Each row is a vector of length n_angles * n_dim + 1, where the n-th successive
      block of n_angles elements is a one-hot encoding of the position in the n-th
      dimension. The last element is n_actions.

      Example, with n_dim = 2 and n_angles = 4::

         state:         [1, 3, 4]
                        | a  | n |    (a = angles, n = n_actions)
         policy format: [0, 1, 0, 0, 0, 0, 0, 1, 4]
                        |     1    |     3     | 4 |

      :param states: A batch of states in environment format, either as a list of
                     states or as a single tensor.
      :type states: list or tensor

      :returns: *A tensor containing all the states in the batch.*

   .. py:method:: state2readable(state = None)

      Converts a state (a list of positions) into a human-readable string
      representing a state.

   .. py:method:: readable2state(readable)

      Converts a human-readable string representing a state into a state as a list
      of positions.

   .. py:method:: get_parents(state = None, done = None, action = None)

      Determines all parents and actions that lead to state.

      :param state: Representation of a state, as a list of length n_angles where each
                    element is the position at each dimension.
      :type state: list
      :param done: Whether the trajectory is done. If None, done is taken from
                   instance.
      :type done: bool
      :param action: Ignored
      :type action: None

      :returns: * **parents** (*list*) -- List of parents in state format
                * **actions** (*list*) -- List of actions that lead to state for each
                  parent in parents

   .. py:method:: step(action, skip_mask_check = False)

      Executes step given an action.

      :param action: Action to be executed. See: get_action_space()
      :type action: tuple
      :param skip_mask_check: If True, skip computing forward mask of invalid actions
                              to check if the action is valid.
      :type skip_mask_check: bool

      :returns: * **self.state** (*list*) -- The state after executing the action
                * **action** (*tuple*) -- Action executed
                * **valid** (*bool*) -- False, if the action is not allowed for the
                  current state.

   .. py:method:: get_all_terminating_states()

   .. py:method:: fit_kde(kernel='exponential', bandwidth=0.1)
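
The two state conversions described for ``states2proxy`` and ``states2policy`` can be
illustrated with a minimal, standalone sketch. It does not import or call
``gflownet`` (the module is flagged as broken above), and it works on a single
state held in a plain Python list rather than a batch. The helper names
``indices_to_radians`` and ``state_to_policy_format`` are hypothetical, and the
mapping from an angle index ``k`` to ``k * 2 * pi / n_angles`` radians is an
assumption that this page does not confirm.

.. code-block:: python

   import math


   def indices_to_radians(state, n_angles):
       """Drop the trailing n_actions item and convert angle indices to radians.

       Assumption (not confirmed by the docs): index k maps to the angle
       k * 2 * pi / n_angles.
       """
       angle_rad = 2 * math.pi / n_angles
       return [index * angle_rad for index in state[:-1]]


   def state_to_policy_format(state, n_angles):
       """One-hot encode each angle index and append n_actions unchanged."""
       *angles, n_actions = state
       encoded = []
       for index in angles:
           block = [0] * n_angles  # one block of n_angles elements per dimension
           block[index] = 1
           encoded.extend(block)
       encoded.append(n_actions)
       return encoded


   # Same example as in the states2policy description: n_dim = 2, n_angles = 4.
   state = [1, 3, 4]  # two angle indices followed by n_actions
   print(state_to_policy_format(state, n_angles=4))
   # [0, 1, 0, 0, 0, 0, 0, 1, 4]
   print(indices_to_radians(state, n_angles=4))

The encoded vector has length ``n_angles * n_dim + 1`` (here 9), matching the row
length stated for the policy format.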