Source code for gflownet.utils.molecule.distributions
from pyro.distributions import ProjectedNormal
from torch.distributions.categorical import Categorical
from torch.distributions.mixture_same_family import MixtureSameFamily
[docs]
def get_mixture_of_projected_normals(weights, concentrations):
    """
    Build a mixture distribution whose components are projected normals.

    :param weights: torch.tensor of shape [*batch_shape, n_components],
        handed to ``Categorical`` as ``probs`` to form the mixing distribution
    :param concentrations: torch.tensor of shape
        [*batch_shape, n_components, n_dim], handed to ``ProjectedNormal``
        to form the component distributions
    :return: ``MixtureSameFamily`` built from the mixing distribution and
        the projected-normal components
    """
    # The mixing distribution is created first, then the components,
    # so construction-time errors surface in the same order as before.
    component_selector = Categorical(probs=weights)
    projected_normals = ProjectedNormal(concentrations)
    return MixtureSameFamily(
        mixture_distribution=component_selector,
        component_distribution=projected_normals,
    )