Source code for gflownet.utils.metrics

import numpy as np
import torch
from sklearn.neighbors import KernelDensity


def fit_kde(samples, kernel="gaussian", bandwidth=0.1):
    """
    Fit a kernel density estimator to a set of samples.

    :param samples: numpy array of shape [batch_size, n_dim]
    :param kernel: kernel name forwarded to sklearn's ``KernelDensity``
        (default ``"gaussian"``)
    :param bandwidth: kernel bandwidth forwarded to ``KernelDensity``
        (default ``0.1``)
    :returns: the fitted ``KernelDensity`` estimator
    """
    estimator = KernelDensity(kernel=kernel, bandwidth=bandwidth)
    # KernelDensity.fit returns the estimator itself, so returning the
    # instance after fitting matches returning the result of ``.fit``.
    estimator.fit(samples)
    return estimator