gflownet.policy.simple_tree
Classes
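Backbone | GNN backbone: a stack of GNN layers that can be used for processing graphs.
Head | GNN head: a combination of a pooling layer and a stack of linear layers that outputs out_dim values for graphs processed by a backbone.
SimpleTreeModel | Combination of backbone and head models. In forward, it converts the states to a batch of graphs and passes them through both models.
SimpleTreePolicy | Policy wrapper using SimpleTreeModel as the policy model.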
Module Contents
- class gflownet.policy.simple_tree.Backbone(input_dim, n_layers=3, hidden_dim=128, layer='GINEConv', activation='LeakyReLU')
Bases: torch.nn.Module
GNN backbone: a stack of GNN layers that can be used for processing graphs.
- Parameters:
input_dim (int)
n_layers (int)
hidden_dim (int)
layer (str)
activation (str)
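To illustrate what "a stack of GNN layers" computes, here is a minimal pure-Python sketch of a GIN-style stack. It is not the library's implementation: `ToyBackbone`, `init_linear`, and `linear` are hypothetical names, the layers ignore edge features (the default `GINEConv` layer does use them), and each layer applies a single linear map plus LeakyReLU instead of a learned MLP. Each layer updates a node as act(W((1 + eps) h_v + sum of neighbour h_u) + b), so after `n_layers` layers a node's embedding depends on its `n_layers`-hop neighbourhood.

```python
import math
import random


def leaky_relu(x, slope=0.01):
    # Same default negative slope as torch.nn.LeakyReLU
    return x if x > 0 else slope * x


def init_linear(in_dim, out_dim, rng):
    # Uniform init in [-1/sqrt(in_dim), 1/sqrt(in_dim)], a stand-in for learned weights
    bound = 1.0 / math.sqrt(in_dim)
    weight = [[rng.uniform(-bound, bound) for _ in range(in_dim)] for _ in range(out_dim)]
    bias = [rng.uniform(-bound, bound) for _ in range(out_dim)]
    return weight, bias


def linear(weight, bias, vec):
    # y = W x + b, with W stored as a list of rows
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]


class ToyBackbone:
    """Hypothetical GIN-style stack: h_v <- act(W((1 + eps) * h_v + sum of neighbour h_u) + b)."""

    def __init__(self, input_dim, n_layers=3, hidden_dim=128, eps=0.0, seed=0):
        rng = random.Random(seed)
        # First layer maps input_dim -> hidden_dim, later layers hidden_dim -> hidden_dim
        dims = [input_dim] + [hidden_dim] * n_layers
        self.layers = [init_linear(dims[i], dims[i + 1], rng) for i in range(n_layers)]
        self.eps = eps

    def forward(self, x, edge_index):
        # x: one feature vector per node; edge_index: directed (src, dst) pairs
        h = x
        for weight, bias in self.layers:
            # Self term scaled by (1 + eps), then add incoming messages from neighbours
            agg = [[(1 + self.eps) * v for v in h_v] for h_v in h]
            for src, dst in edge_index:
                agg[dst] = [a + s for a, s in zip(agg[dst], h[src])]
            h = [[leaky_relu(z) for z in linear(weight, bias, a)] for a in agg]
        return h


# A 3-node path graph 0-1-2, with each undirected edge stored in both directions
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
backbone = ToyBackbone(input_dim=2, n_layers=3, hidden_dim=4)
h = backbone.forward(x, edges)
print(len(h), len(h[0]))  # 3 4
```

The output keeps one row per node, so a graph-level prediction still needs a pooling step on top of the backbone.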
- class gflownet.policy.simple_tree.Head(backbone, out_dim, n_layers=2, hidden_dim=256, activation='LeakyReLU')
Bases: torch.nn.Module
GNN head: a combination of a pooling layer and a stack of linear layers that takes the input graphs processed by the provided backbone and outputs out_dim values.
This is a naive variant of a Tree policy model: in particular, it does not do any node-level prediction, but instead predicts the whole policy output all at once.
- Parameters:
backbone (torch.nn.Module)
out_dim (int)
n_layers (int)
hidden_dim (int)
activation (str)
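The "naive" design above means the head always emits a fixed-size vector per graph, no matter how many nodes the tree has. A pure-Python sketch of that idea follows. It is hypothetical throughout: the documentation only says "a pooling layer", so mean pooling is an assumption, and `ToyHead` takes the embedding size `in_dim` directly rather than a backbone object. The `batch` list maps each node to the graph it belongs to.

```python
import math
import random


def leaky_relu(x, slope=0.01):
    return x if x > 0 else slope * x


def init_linear(in_dim, out_dim, rng):
    bound = 1.0 / math.sqrt(in_dim)
    weight = [[rng.uniform(-bound, bound) for _ in range(in_dim)] for _ in range(out_dim)]
    bias = [rng.uniform(-bound, bound) for _ in range(out_dim)]
    return weight, bias


def linear(weight, bias, vec):
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]


def mean_pool(h, batch, n_graphs):
    # batch[i] is the index of the graph that node i belongs to
    sums = [[0.0] * len(h[0]) for _ in range(n_graphs)]
    counts = [0] * n_graphs
    for vec, g in zip(h, batch):
        sums[g] = [s + v for s, v in zip(sums[g], vec)]
        counts[g] += 1
    return [[s / max(c, 1) for s in row] for row, c in zip(sums, counts)]


class ToyHead:
    """Hypothetical head: mean-pool node embeddings per graph, then an MLP to out_dim."""

    def __init__(self, in_dim, out_dim, n_layers=2, hidden_dim=256, seed=0):
        rng = random.Random(seed)
        dims = [in_dim] + [hidden_dim] * (n_layers - 1) + [out_dim]
        self.layers = [init_linear(dims[i], dims[i + 1], rng) for i in range(n_layers)]

    def forward(self, h, batch, n_graphs):
        out = []
        for z in mean_pool(h, batch, n_graphs):
            for i, (weight, bias) in enumerate(self.layers):
                z = linear(weight, bias, z)
                if i < len(self.layers) - 1:  # no activation after the output layer
                    z = [leaky_relu(v) for v in z]
            out.append(z)
        return out


# Node embeddings for two graphs: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
h = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [0.0, 1.0], [1.0, 0.0]]
batch = [0, 0, 0, 1, 1]
print(mean_pool(h, batch, 2))  # [[3.0, 4.0], [0.5, 0.5]]
logits = ToyHead(in_dim=2, out_dim=5, hidden_dim=8).forward(h, batch, 2)
print(len(logits), len(logits[0]))  # 2 5
```

Because pooling discards which node produced which feature, this kind of head cannot attach scores to individual nodes; it predicts every entry of the policy output from a single graph summary.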
- class gflownet.policy.simple_tree.SimpleTreeModel(n_features, policy_output_dim, base=None, backbone_args=None, head_args=None)
Bases: torch.nn.Module
Combination of backbone and head models. In forward, it converts the states to a batch of graphs, and passes them through both models.
- Parameters:
n_features (int)
policy_output_dim (int)
base (Optional[SimpleTreePolicy])
backbone_args (Optional[dict])
head_args (Optional[dict])
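The forward pass described above (states to a batch of graphs, then backbone, then head) can be sketched in pure Python as follows. Everything here is hypothetical: the parent-list tree encoding in `tree_to_graph`, the constant node feature, and the `ToySimpleTreeModel` class are illustrative assumptions, not the library's state format. `batch_graphs` follows the common "disjoint union" batching scheme (concatenate node features, shift edge indices by a running offset, record a node-to-graph index), the same idea used by graph libraries such as PyTorch Geometric.

```python
def tree_to_graph(parents):
    """Hypothetical state encoding: parents[i] is the parent of node i, or -1 for the root."""
    nodes = [[1.0] for _ in parents]  # one constant feature per node, for illustration only
    edges = []
    for child, parent in enumerate(parents):
        if parent >= 0:
            edges += [(parent, child), (child, parent)]
    return nodes, edges


def batch_graphs(graphs):
    """Merge (nodes, edges) pairs into one disjoint graph plus a node-to-graph batch vector."""
    x, edge_index, batch = [], [], []
    offset = 0
    for g, (nodes, edges) in enumerate(graphs):
        x.extend(nodes)
        # Shift edge endpoints so they index into the concatenated node list
        edge_index.extend((src + offset, dst + offset) for src, dst in edges)
        batch.extend([g] * len(nodes))
        offset += len(nodes)
    return x, edge_index, batch


class ToySimpleTreeModel:
    """Hypothetical composition: states -> batched graph -> backbone -> head."""

    def __init__(self, state_to_graph, backbone, head):
        self.state_to_graph = state_to_graph
        self.backbone = backbone
        self.head = head

    def forward(self, states):
        x, edge_index, batch = batch_graphs([self.state_to_graph(s) for s in states])
        h = self.backbone(x, edge_index)
        return self.head(h, batch, len(states))


# Two trees: a root with two children, and a root with one child
states = [[-1, 0, 0], [-1, 0]]
x, edge_index, batch = batch_graphs([tree_to_graph(s) for s in states])
print(edge_index)  # [(0, 1), (1, 0), (0, 2), (2, 0), (3, 4), (4, 3)]
print(batch)  # [0, 0, 0, 1, 1]

# Stand-ins: an identity backbone and a head that counts nodes per graph
model = ToySimpleTreeModel(
    tree_to_graph,
    backbone=lambda x, edge_index: x,
    head=lambda h, batch, n: [[float(batch.count(g))] for g in range(n)],
)
print(model.forward(states))  # [[3.0], [2.0]]
```

Batching this way lets the backbone process many trees in one pass, since no edge crosses between graphs; the `batch` vector is what the head's pooling step uses to separate them again.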
- class gflownet.policy.simple_tree.SimpleTreePolicy(config, env, device, float_precision, base=None)
Bases: gflownet.policy.base.Policy
Policy wrapper using SimpleTreeModel as the policy model.
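A policy wrapper's job is to hold the model and expose its outputs as a policy. The sketch below is hypothetical and does not reproduce the `SimpleTreePolicy(config, env, device, float_precision, base=None)` constructor: `ToyPolicyWrapper`, `action_probs`, and the `valid` mask (True meaning the action is allowed) are illustrative assumptions. It shows one common way raw per-state logits of length `policy_output_dim` can be turned into a distribution over actions: mask disallowed actions, then apply a numerically stable softmax.

```python
import math


class ToyPolicyWrapper:
    """Hypothetical wrapper: calls a model on states and masks its logits into action probabilities."""

    def __init__(self, model):
        self.model = model

    def __call__(self, states):
        # One row of raw logits (length policy_output_dim) per state
        return self.model(states)

    def action_probs(self, states, valid):
        # valid[i][a] is True when action a is allowed in state i (an assumption of this sketch)
        probs = []
        for logits, ok in zip(self(states), valid):
            # Subtract the largest allowed logit for numerical stability
            m = max(logit for logit, v in zip(logits, ok) if v)
            exps = [math.exp(logit - m) if v else 0.0 for logit, v in zip(logits, ok)]
            z = sum(exps)
            probs.append([e / z for e in exps])
        return probs


# A stand-in model that puts a large logit on the last action for every state
policy = ToyPolicyWrapper(lambda states: [[0.0, 0.0, 5.0] for _ in states])
print(policy.action_probs(["s0"], valid=[[True, True, False]]))  # [[0.5, 0.5, 0.0]]
```

Masking before the softmax guarantees that a disallowed action gets probability exactly zero, however large the model's logit for it is.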