Source code for gflownet.utils.molecule.atom_positions_dataset

import numpy as np

from gflownet.utils.common import download_file_if_not_exists


class AtomPositionsDataset:
    """Array-backed dataset of atom positions loaded from a ``.npy``-style file.

    The whole array is loaded into memory with ``np.load``. Items are indexed
    along the first axis. Presumably each entry along axis 0 is one set of
    atom positions (e.g. one conformation); confirm against the data file.
    """

    def __init__(self, path_to_data, url_to_data):
        """Load the positions array, downloading the file first if needed.

        Args:
            path_to_data: Local path where the data file is (or will be) stored.
            url_to_data: URL to fetch the file from if it is not present locally.
        """
        # The helper's return value is used as the path to load. Presumably it
        # is the local path after an optional download; confirm in
        # gflownet.utils.common.
        path_to_data = download_file_if_not_exists(path_to_data, url_to_data)
        self.positions = np.load(path_to_data)

    def __getitem__(self, i):
        """Return ``positions[i]``; any numpy index (int, slice, array) works."""
        return self.positions[i]

    def __len__(self):
        """Return the number of entries along the first axis of ``positions``."""
        return self.positions.shape[0]

    def sample(self, size=None, rng=None):
        """Draw entries uniformly at random, with replacement.

        Args:
            size: Output shape of the index draw, passed straight to numpy.
                ``None`` (default) returns a single entry, ``positions[k]``.
                An int or tuple returns a stacked array of entries.
            rng: Optional random source. Pass an ``np.random.Generator`` or an
                ``np.random.RandomState`` for reproducible sampling. ``None``
                (default) uses the global ``np.random`` state, which matches
                the original behaviour.

        Returns:
            The sampled entries of ``positions``.

        Raises:
            ValueError: If the dataset is empty.
        """
        n = len(self)
        if n == 0:
            # Without this check numpy raises an unhelpful "high <= low".
            raise ValueError("Cannot sample from an empty AtomPositionsDataset.")
        if rng is None:
            idx = np.random.randint(0, n, size=size)
        elif isinstance(rng, np.random.Generator):
            # The Generator API uses `integers` (half-open [0, n), like randint).
            idx = rng.integers(0, n, size=size)
        else:
            idx = rng.randint(0, n, size=size)
        return self.positions[idx]