from abc import ABC, abstractmethod
import torch
from omegaconf import OmegaConf
from torch import nn
from gflownet.utils.common import set_device, set_float_precision
[docs]
class ModelBase(ABC):
    """
    Abstract base class of policy models.

    On construction it stores the device and float precision, reads the input and
    output dimensions of the policy from the environment, keeps a reference to an
    optional base model and parses the policy configuration.

    Subclasses must implement ``instantiate()``. ``__call__`` forwards to
    ``self.model``, which is not set by this class and must therefore be set by
    the subclass.
    """

    def __init__(self, config, env, device, float_precision, base=None):
        """
        Parameters
        ----------
        config : dict-like or None
            Policy configuration, see ``parse_config()``. If None, the policy type
            defaults to "uniform".
        env : object
            Environment providing the attributes ``policy_input_dim``,
            ``fixed_policy_output`` and ``random_policy_output``.
        device : object
            Device specification, passed through ``set_device()``.
        float_precision : object
            Float precision specification, passed through
            ``set_float_precision()``.
        base : ModelBase or None
            Optional model to share weights with (see ``make_mlp()``).
        """
        # Device and float precision
        self.device = set_device(device)
        self.float = set_float_precision(float_precision)
        # Input and output dimensions
        self.state_dim = env.policy_input_dim
        self.fixed_output = env.fixed_policy_output
        self.random_output = env.random_policy_output
        self.output_dim = len(self.fixed_output)
        # Optional base model. It must be stored before parse_config() is called
        # because parse_config() reads self.base.type when weights are shared and
        # the configuration does not define a type.
        self.base = base
        self.parse_config(config)

    def parse_config(self, config):
        """
        Reads the policy configuration into attributes of the model.

        Sets ``checkpoint``, ``shared_weights``, ``n_hid``, ``n_layers``, ``tail``
        and ``type``. If the configuration does not contain "type" and
        ``shared_weights`` is set, the type of the base model is used.

        Parameters
        ----------
        config : dict-like or None
            Configuration supporting ``get()``, ``in`` and attribute access to
            ``type``. If None, an empty configuration of type "uniform" is used.

        Raises
        ------
        ValueError
            If the type is not defined and weights are not shared, or if weights
            are shared but no base model is available to take the type from.
        """
        # If config is null, default to uniform
        if config is None:
            config = OmegaConf.create()
            config.type = "uniform"
        self.checkpoint = config.get("checkpoint", None)
        self.shared_weights = config.get("shared_weights", False)
        self.n_hid = config.get("n_hid", None)
        self.n_layers = config.get("n_layers", None)
        self.tail = config.get("tail", [])
        if "type" in config:
            self.type = config.type
        elif self.shared_weights:
            if self.base is None:
                raise ValueError(
                    "Base Model must be provided when shared_weights is set to True"
                )
            self.type = self.base.type
        else:
            # A bare string cannot be raised in Python 3; raise a proper exception.
            raise ValueError("Policy type must be defined if shared_weights is False")

    @abstractmethod
    def instantiate(self):
        """
        Instantiates the model. Must be implemented by subclasses.
        """
        pass

    def __call__(self, states):
        """
        Returns the output of ``self.model`` for the input states.
        """
        return self.model(states)

    def make_mlp(self, activation):
        """
        Defines an MLP with no top layer activation.

        If ``self.shared_weights`` is set, the MLP re-uses all the layers of the
        base model (``self.base.model``) except the last one, and a new linear
        layer with the same input and output sizes as the last layer of the base
        model is appended. In this case, the base model must have been provided.

        Otherwise, the MLP consists of ``self.n_layers`` hidden layers of
        ``self.n_hid`` units between an input of size ``self.state_dim`` and an
        output of size ``self.output_dim``, followed by the modules in
        ``self.tail``. The same ``activation`` object is inserted after every
        linear layer except the last one.

        Args
        ----
        activation : Activation
            Activation function

        Returns
        -------
        torch.nn.Sequential
            The MLP.

        Raises
        ------
        ValueError
            If weights are to be shared but no base model was provided, or if
            ``n_hid`` or ``n_layers`` are not defined when a new MLP is built.
        """
        if self.shared_weights:
            if self.base is None:
                raise ValueError(
                    "Base Model must be provided when shared_weights is set to True"
                )
            base_top = self.base.model[-1]
            return nn.Sequential(
                self.base.model[:-1],
                nn.Linear(
                    base_top.in_features,
                    base_top.out_features,
                    dtype=self.float,
                ),
            )

        if self.n_hid is None or self.n_layers is None:
            raise ValueError("n_hid and n_layers must be defined to make an MLP")

        layers_dim = (
            [self.state_dim] + [self.n_hid] * self.n_layers + [self.output_dim]
        )
        idx_top = len(layers_dim) - 2
        layers = []
        for idx, (idim, odim) in enumerate(zip(layers_dim, layers_dim[1:])):
            layers.append(nn.Linear(idim, odim, dtype=self.float))
            # No activation after the top (last) linear layer
            if idx < idx_top:
                layers.append(activation)
        # The tail may be any iterable of modules (or None)
        layers.extend(self.tail or [])
        return nn.Sequential(*layers)
[docs]
class Policy(ModelBase):
def __init__(self, config, env, device, float_precision, base=None):
    """
    Initializes the base model attributes and then instantiates the policy model.

    All the arguments are forwarded unchanged to the constructor of the parent
    class.
    """
    super().__init__(
        config=config,
        env=env,
        device=device,
        float_precision=float_precision,
        base=base,
    )
    # The parent constructor parses the configuration, which sets self.type; the
    # type is needed to instantiate the model.
    self.instantiate()
[docs]
def instantiate(self):
    """
    Sets ``self.model`` and ``self.is_model`` according to ``self.type``.

    - "fixed": the model is the method ``fixed_distribution``.
    - "uniform": the model is the method ``uniform_distribution``.
    - "mlp": the model is an MLP built by ``make_mlp()`` with LeakyReLU
      activations and moved to ``self.device``.

    ``self.is_model`` is True only for the "mlp" type.

    Raises
    ------
    ValueError
        If ``self.type`` is not one of the types above.
    """
    if self.type == "fixed":
        self.model = self.fixed_distribution
        self.is_model = False
    elif self.type == "uniform":
        self.model = self.uniform_distribution
        self.is_model = False
    elif self.type == "mlp":
        self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
        self.is_model = True
    else:
        # A bare string cannot be raised in Python 3; raise a proper exception.
        raise ValueError(f"Policy model type not defined: {self.type}")
[docs]
def fixed_distribution(self, states):
    """
    Returns the fixed distribution specified by the environment.

    ``self.fixed_output`` is tiled ``len(states)`` times and the result is
    converted to the float precision and device of the model.

    Parameters
    ----------
    states : tensor
        The states for which the fixed distribution is to be returned
    """
    n_states = len(states)
    tiled = torch.tile(self.fixed_output, (n_states, 1))
    return tiled.to(dtype=self.float, device=self.device)
[docs]
def random_distribution(self, states):
    """
    Returns the random distribution specified by the environment.

    ``self.random_output`` is tiled ``len(states)`` times and the result is
    converted to the float precision and device of the model.

    Parameters
    ----------
    states : tensor
        The states for which the random distribution is to be returned
    """
    n_states = len(states)
    tiled = torch.tile(self.random_output, (n_states, 1))
    return tiled.to(dtype=self.float, device=self.device)