gflownet.buffer.base

Base Buffer class to handle train and test data sets, replay buffer, etc.

Classes

BaseBuffer

Initializes the Buffer.

Module Contents

class gflownet.buffer.base.BaseBuffer(env, proxy, datadir, replay_buffer=None, replay_capacity=0, train=None, test=None, use_main_buffer=False, check_diversity=False, diversity_check_reward_similarity=0.1, progress_process_dataset=False, **kwargs)

Initializes the Buffer.

Parameters:
  • datadir (str or PosixPath) – The directory where the data sets and buffers are stored. The default is ./data/, but the value is first set by the logger and then passed to the Buffer for consistency, especially to handle resumed runs.

  • replay_buffer (str or PosixPath) – A path to a file containing a replay buffer. If provided, the initial replay buffer will be loaded from this file. This is useful for resuming runs. By default it is None, which initializes an empty buffer and creates a new file.

  • replay_capacity (int) – Maximum number of samples held in the replay buffer. The default is 0, in which case no replay buffer is used.

  • train (dict) –

    A dictionary describing the training data. The dictionary can have the following keys:

    • type (str)
      Type of data. It can be one of the following:
      • pkl: a pickled file. Requires path.

      • csv: a CSV file. Requires path.

      • all: all terminating states of the environment.

      • grid: a grid of terminating states. Requires n.

      • uniform: terminating states sampled uniformly. Requires n.

      • random: terminating states sampled randomly from the initial GFlowNet policy. Requires n.

    • path (str)

      Path to a CSV or pickled file (for type={pkl, csv}).

    • n (int)

      Number of samples (for type={grid, uniform, random}).

    • seed (int)

      Seed for random sampling (for type={uniform, random}).

  • test (dict) – A dictionary describing the test data. It has the same structure as the train dictionary.

  • use_main_buffer (bool) – If True, a main buffer is kept up to date, that is, all training samples are added to it. It is False by default because of the potentially large memory usage this can incur.

  • check_diversity (bool) – If True, new samples are only added to the buffer if they are not close to any of the samples already present in the buffer. env.isclose() is used for the comparison. It is False by default because this comparison can easily dominate the running time, with an uncertain impact on performance. The implementation needs to be improved to make this option functional.

  • diversity_check_reward_similarity (float) – The accepted level of reward similarity for a replay-buffer sample to be included in the diversity check. Assuming check_diversity is True, a new sample x with reward R(x) is compared only against those replay-buffer samples x' for which |R(x) - R(x')| is smaller than diversity_check_reward_similarity times (R_max - R_min), where R_max and R_min are the maximum and minimum rewards in the replay buffer. By default, it is 0.1. If the value is negative (for example, -1), the diversity check is performed against the full replay buffer. Note also that a value of 0.0 is equivalent to not doing any diversity check at all, since no reward difference is smaller than zero.

  • progress_process_dataset (bool) – Whether to show a progress bar while processing the data sets. False by default.
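The reward-similarity rule for diversity_check_reward_similarity can be sketched as a filter that returns the replay-buffer entries a new sample would be compared against. This is a minimal sketch, not the library's implementation: the helper name diversity_candidates is hypothetical, and it only selects the candidates; the actual closeness test with env.isclose() is not shown.

```python
import numpy as np


def diversity_candidates(reward_x, replay_rewards, similarity=0.1):
    """Hypothetical helper: indices of replay samples to compare a new sample against.

    A replay sample x' is a candidate if |R(x) - R(x')| < similarity * (R_max - R_min),
    with R_max and R_min taken over the replay buffer.
    """
    replay_rewards = np.asarray(replay_rewards, dtype=float)
    if replay_rewards.size == 0:
        return np.array([], dtype=int)
    if similarity < 0.0:
        # Negative values: compare against the full replay buffer
        return np.arange(replay_rewards.size)
    threshold = similarity * (replay_rewards.max() - replay_rewards.min())
    diffs = np.abs(replay_rewards - reward_x)
    # Strict inequality, so a similarity of 0.0 selects no candidates at all
    return np.flatnonzero(diffs < threshold)
```

For example, with replay rewards [0.0, 0.5, 0.9, 1.0] and a new reward of 0.95, the default 0.1 gives a threshold of 0.1, so only the entries with rewards 0.9 and 1.0 would be checked.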

datadir
env
proxy
replay_capacity = 0
train_config = None
test_config = None
use_main_buffer = False
check_diversity = False
diversity_check_reward_similarity = 0.1
progress_process_dataset = False
replay_updated = False
init_replay(replay_buffer_path=None)

Initializes the replay buffer.

If a path to an existing replay buffer file is provided, then the replay buffer is initialized from it. Otherwise, a new empty buffer is created.

Parameters:
  • replay_buffer_path (str or PosixPath) – A path to a file containing a replay buffer. If provided, the initial replay buffer will be loaded from this file. This is useful for resuming runs. By default it is None, which initializes an empty buffer and creates a new file.

Returns:

  • replay (pandas.DataFrame) – DataFrame with the initial replay buffer.

  • replay_csv (PosixPath) – Path of the CSV that will store the replay buffer.
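A minimal sketch of the behaviour described for init_replay, assuming the buffer is stored as a CSV with the hypothetical columns samples, trajectories, rewards and iter, and that a fresh buffer goes to a hypothetical file named replay.csv inside datadir. The function name init_replay_sketch is also hypothetical; the library's actual file name and columns may differ.

```python
from pathlib import Path

import pandas as pd

# Hypothetical column layout for the replay buffer CSV
REPLAY_COLUMNS = ["samples", "trajectories", "rewards", "iter"]


def init_replay_sketch(datadir, replay_buffer_path=None):
    """Return (replay, replay_csv) as described in the init_replay docstring."""
    datadir = Path(datadir)
    if replay_buffer_path is not None and Path(replay_buffer_path).exists():
        # Resumed run: load the existing buffer and keep writing to the same file
        replay_csv = Path(replay_buffer_path)
        replay = pd.read_csv(replay_csv)
    else:
        # Fresh run: start from an empty buffer stored in a new file under datadir
        replay_csv = datadir / "replay.csv"
        replay = pd.DataFrame(columns=REPLAY_COLUMNS)
    return replay, replay_csv
```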

property replay_samples
property replay_trajectories
property replay_rewards
save_replay()
load_replay_from_path(path=None)

Loads a replay buffer stored as a CSV file.

Parameters:
  • path (pathlib.PosixPath) – Path to the CSV file containing the replay buffer.

add(samples, trajectories, rewards, it, buffer='main', criterion='greater')

Adds a batch of samples (with the trajectory actions and rewards) to the buffer.

Parameters:
  • samples (list) – A batch of terminating states.

  • trajectories (list) – The list of trajectory actions of each terminating state.

  • rewards (list or tensor) – The reward of each terminating state.

  • it (int) – Iteration number.

  • buffer (str) – Identifier of the buffer: main or replay.

  • criterion (str) – Identifier of the criterion. Currently, only greater is implemented.
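The docstring does not spell out what the greater criterion does. One plausible reading, sketched here as an assumption rather than the library's behaviour, is: while the replay buffer holds fewer than replay_capacity samples, every new sample is admitted; once it is full, a new sample replaces the lowest-reward entry only if its reward is strictly greater. The function add_greater and the column names are hypothetical.

```python
import pandas as pd


def add_greater(replay, samples, trajectories, rewards, it, capacity):
    """Hypothetical 'greater' admission rule for a capacity-limited replay buffer."""
    if capacity <= 0:
        # Zero capacity: no replay buffer is used, so nothing is added
        return replay
    rows = replay.to_dict("records")
    for x, traj, r in zip(samples, trajectories, rewards):
        new_row = {"samples": x, "trajectories": traj, "rewards": float(r), "iter": it}
        if len(rows) < capacity:
            # Buffer not full yet: always admit the new sample
            rows.append(new_row)
            continue
        # Buffer full: find the entry with the smallest reward
        i_min = min(range(len(rows)), key=lambda i: rows[i]["rewards"])
        if new_row["rewards"] > rows[i_min]["rewards"]:
            # Replace it only if the new reward is strictly greater
            rows[i_min] = new_row
    return pd.DataFrame(rows, columns=["samples", "trajectories", "rewards", "iter"])
```

Under this reading, the replay buffer converges towards the replay_capacity highest-reward samples seen so far.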

make_data_set(config)

Constructs a data set as a DataFrame according to the configuration.
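For illustration, the train and test dictionaries documented for BaseBuffer could look as follows; the paths and values are placeholders. The helper check_data_config is hypothetical and not part of the library: it only encodes the "Requires path" and "Requires n" rules listed in the constructor docstring.

```python
# Required keys per data set type, as listed in the BaseBuffer docstring
REQUIRED_KEYS = {
    "pkl": {"path"},
    "csv": {"path"},
    "all": set(),
    "grid": {"n"},
    "uniform": {"n"},
    "random": {"n"},
}


def check_data_config(config):
    """Hypothetical helper: raise ValueError if a train/test config is incomplete."""
    if config is None:
        return  # train=None / test=None: no data set is requested
    data_type = config.get("type")
    if data_type not in REQUIRED_KEYS:
        raise ValueError(f"Unknown data set type: {data_type!r}")
    missing = REQUIRED_KEYS[data_type] - set(config)
    if missing:
        raise ValueError(f"type={data_type!r} requires keys: {sorted(missing)}")


# Example configurations (placeholder path and values)
train = {"type": "csv", "path": "./data/train.csv"}
test = {"type": "uniform", "n": 100, "seed": 0}
```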

static compute_stats(data)
static select(df, n, mode='permutation', rng=None)

Selects a subset of n data points from df, according to the criterion indicated by mode.

The data set may be a training set or a replay buffer.

The mode argument can be one of the following:
  • permutation: data points are sampled uniformly from the data set, without replacement, using the random generator rng.

  • uniform: data points are sampled uniformly from the data set, with replacement, using the random generator rng.

  • weighted: data points are sampled with probability proportional to their score.

Parameters:
  • df (pandas.DataFrame) – The data samples, one row per sample and one column per sample attribute. If mode == “weighted”, df must contain the sample scores (column “scores” or “rewards”).

  • n (int) – The number of samples to select.

  • mode (str) – Sampling mode. Options: permutation, uniform, weighted.

  • rng (np.random.Generator) – A numpy random number generator, used for the permutation and uniform modes. Ignored otherwise.

Returns:

A DataFrame containing the data of the n selected samples.

Return type:

pandas.DataFrame
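The three selection modes can be sketched with NumPy's Generator API. This is a sketch under stated assumptions, not the library's code: select_sketch is a hypothetical name, the weighted mode here samples with replacement and draws from rng (the docstring says rng is ignored in that mode), and scores are assumed to be non-negative.

```python
import numpy as np
import pandas as pd


def select_sketch(df, n, mode="permutation", rng=None):
    """Hypothetical sketch of the permutation, uniform and weighted modes."""
    if rng is None:
        rng = np.random.default_rng()
    if mode == "permutation":
        # Uniform without replacement; returns at most len(df) rows
        idx = rng.permutation(len(df))[:n]
    elif mode == "uniform":
        # Uniform with replacement
        idx = rng.integers(0, len(df), size=n)
    elif mode == "weighted":
        # Probability proportional to the score column (assumed non-negative)
        col = "scores" if "scores" in df.columns else "rewards"
        p = df[col].to_numpy(dtype=float)
        idx = rng.choice(len(df), size=n, replace=True, p=p / p.sum())
    else:
        raise ValueError(f"Unknown mode: {mode!r}")
    return df.iloc[idx]
```

Using positional indexing (iloc) keeps the sketch independent of the DataFrame's index labels, which matters for a replay buffer whose rows have been replaced over time.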