gflownet.policy.multihead_tree
Classes
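- Backbone: GNN backbone, a stack of GNN layers that can be used for processing graphs.
- LeafSelectionHead: Node-level prediction head. Consists of a stack of GNN layers and, if model_eos is True, a separate linear layer for modeling the exit action.
- FeatureSelectionHead: A graph-level prediction head that pools the representations from the backbone and passes them through an MLP.
- ThresholdSelectionHead: A graph-level prediction head that pools the representations from the backbone and passes them through an MLP.
- OperatorSelectionHead: A graph-level prediction head that pools the representations from the backbone and passes them through an MLP.
- ForwardTreeModel: A model that combines the backbone and several output heads, which will be used depending on the current stage of the passed state.
- BackwardTreeModel: A model that combines the backbone and several output heads, which will be used depending on the current stage of the passed state.
- MultiheadTreePolicy: Policy wrapper using ForwardTreeModel and BackwardTreeModel as the policy models.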
Module Contents
- class gflownet.policy.multihead_tree.Backbone(input_dim, n_layers=3, hidden_dim=64, layer='GCNConv', activation='LeakyReLU')
Bases: torch.nn.Module
GNN backbone: a stack of GNN layers that can be used for processing graphs.
- Parameters:
input_dim (int)
n_layers (int)
hidden_dim (int)
layer (str)
activation (str)
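To illustrate what a stack of GCN-style layers computes, here is a minimal pure-Python sketch. It assumes a graph given as per-node feature lists plus an undirected edge list, uses the symmetric normalization with self-loops from the original GCN formulation, and applies LeakyReLU after every layer. The names `gcn_layer` and `backbone_forward` are hypothetical and not part of this module; the real class presumably builds PyTorch Geometric layers chosen through the `layer` argument.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def leaky_relu(x: float, slope: float = 0.01) -> float:
    return x if x > 0 else slope * x


def gcn_layer(h: Matrix, edges: Sequence[Tuple[int, int]], weight: Matrix) -> Matrix:
    """One GCN-style layer: normalized neighbourhood sum, linear map, LeakyReLU."""
    n = len(h)
    # Each neighbourhood includes a self-loop, as in GCN.
    neighbours = [{i} for i in range(n)]
    for u, v in edges:
        neighbours[u].add(v)
        neighbours[v].add(u)
    degree = [len(nb) for nb in neighbours]
    out = []
    for i in range(n):
        agg = [0.0] * len(h[i])
        for j in neighbours[i]:
            norm = 1.0 / math.sqrt(degree[i] * degree[j])  # symmetric normalization
            for k, x in enumerate(h[j]):
                agg[k] += norm * x
        # weight has shape (out_dim, in_dim)
        projected = [sum(w * a for w, a in zip(row, agg)) for row in weight]
        out.append([leaky_relu(x) for x in projected])
    return out


def backbone_forward(
    x: Matrix, edges: Sequence[Tuple[int, int]], weights: Sequence[Matrix]
) -> Matrix:
    """Hypothetical backbone: one gcn_layer per weight matrix (n_layers = len(weights))."""
    h = x
    for weight in weights:
        h = gcn_layer(h, edges, weight)
    return h
```

For example, `backbone_forward([[1.0], [1.0]], [(0, 1)], [[[1.0]]] * 3)` returns `[[1.0], [1.0]]`: each node combines itself and its single neighbour with weight 1/2 at every layer.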
- class gflownet.policy.multihead_tree.LeafSelectionHead(backbone, max_nodes, model_eos=True, n_layers=2, hidden_dim=64, layer='GCNConv', activation='LeakyReLU')
Bases: torch.nn.Module
Node-level prediction head. Consists of a stack of GNN layers and, if model_eos is True, a separate linear layer for modeling the exit action.
Note that the forward function converts the node-level predictions into the expected vector policy output. Because of that, the output is a regular tensor, with logits at the correct positions regardless of the graph shape.
- Parameters:
backbone (torch.nn.Module)
max_nodes (int)
model_eos (bool)
n_layers (int)
hidden_dim (int)
layer (str)
activation (str)
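The conversion from node-level predictions to a fixed-size policy vector can be sketched in plain Python. The sketch assumes PyTorch Geometric-style batching (nodes of all graphs concatenated, a `batch` list giving each node's graph, and each graph's nodes stored contiguously and in order), a hypothetical layout of `max_nodes` slots for node logits followed by one optional exit (EOS) logit, and `-inf` for slots beyond a graph's size. The function name and the layout are illustrative; the module itself places logits according to its own indices.

```python
import math
from typing import List, Optional, Sequence


def nodes_to_policy_rows(
    node_logits: Sequence[float],
    batch: Sequence[int],
    n_graphs: int,
    max_nodes: int,
    eos_logits: Optional[Sequence[float]] = None,
) -> List[List[float]]:
    """Scatter per-node logits into one fixed-length row per graph (hypothetical layout)."""
    # Slots with no corresponding node stay at -inf, i.e. zero probability after softmax.
    rows = [[-math.inf] * max_nodes for _ in range(n_graphs)]
    filled = [0] * n_graphs
    for logit, graph in zip(node_logits, batch):
        position = filled[graph]  # position of this node within its own graph
        if position >= max_nodes:
            raise ValueError(f"graph {graph} has more than {max_nodes} nodes")
        rows[graph][position] = logit
        filled[graph] += 1
    if eos_logits is not None:  # analogous to model_eos=True
        for row, eos in zip(rows, eos_logits):
            row.append(eos)
    return rows
```

With two graphs of sizes 2 and 1 and `max_nodes=3`, both rows have length 4: `[[0.1, 0.2, -inf, 1.0], [0.3, -inf, -inf, 2.0]]` for node logits `[0.1, 0.2, 0.3]` and EOS logits `[1.0, 2.0]`.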
- class gflownet.policy.multihead_tree.FeatureSelectionHead(backbone, input_dim, output_dim, n_layers=2, hidden_dim=64, activation='LeakyReLU')
Bases: torch.nn.Module
A graph-level prediction head that pools the representations from the backbone and passes them through an MLP.
The output dimensionality is expected to equal the number of available features.
- Parameters:
backbone (torch.nn.Module)
input_dim (int)
output_dim (int)
n_layers (int)
hidden_dim (int)
activation (str)
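A plain-Python sketch of such a graph-level head: node embeddings from the backbone are pooled per graph, and each pooled vector goes through a small MLP whose last layer has one output per available feature. Mean pooling, LeakyReLU with slope 0.01 between layers, and the helper names are assumptions for illustration; this page does not state which pooling the module uses.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]
Layer = Tuple[Matrix, List[float]]  # (weight of shape (out, in), bias of length out)


def mean_pool(h: Matrix, batch: Sequence[int], n_graphs: int) -> Matrix:
    """Average the node embeddings of each graph (assumed pooling choice)."""
    dim = len(h[0])
    sums = [[0.0] * dim for _ in range(n_graphs)]
    counts = [0] * n_graphs
    for row, graph in zip(h, batch):
        for k, x in enumerate(row):
            sums[graph][k] += x
        counts[graph] += 1
    return [[s / max(c, 1) for s in row] for row, c in zip(sums, counts)]


def mlp(x: Sequence[float], layers: Sequence[Layer], slope: float = 0.01) -> List[float]:
    """Affine layers with LeakyReLU in between (none after the last layer)."""
    out = list(x)
    for i, (weight, bias) in enumerate(layers):
        out = [sum(w * v for w, v in zip(row, out)) + b for row, b in zip(weight, bias)]
        if i < len(layers) - 1:
            out = [v if v > 0 else slope * v for v in out]
    return out


def feature_head(h: Matrix, batch: Sequence[int], n_graphs: int, layers: Sequence[Layer]) -> Matrix:
    """One row of feature logits per graph; the last layer's width is the number of features."""
    return [mlp(pooled, layers) for pooled in mean_pool(h, batch, n_graphs)]
```

Passing `layers` whose final weight matrix has `n_features` rows yields `n_features` logits per graph.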
- class gflownet.policy.multihead_tree.ThresholdSelectionHead(backbone, input_dim, output_dim, n_layers=2, hidden_dim=64, activation='LeakyReLU')
Bases: torch.nn.Module
A graph-level prediction head that pools the representations from the backbone and passes them through an MLP.
The MLP input is expected to have dimensionality equal to the number of available features plus one, with the last element holding the feature selected in the previous stage (concatenated with the pooled graph representation).
- Parameters:
backbone (torch.nn.Module)
input_dim (int)
output_dim (int)
n_layers (int)
hidden_dim (int)
activation (str)
- forward(data, feature_index)
- Parameters:
data (torch_geometric.data.Data)
feature_index (torch.Tensor)
- Return type:
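The conditioning described for this head (concatenating the previously selected feature with the pooled graph representation before the MLP) can be sketched in plain Python with a single affine output layer. Encoding the feature index as one raw scalar, and the name `threshold_head_forward`, are assumptions; the module may encode the feature differently. The output width (the number of rows of `weight`) depends on how thresholds are parameterized, which this page does not specify.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def threshold_head_forward(
    pooled: Matrix,
    feature_index: Sequence[int],
    weight: Matrix,
    bias: Sequence[float],
) -> Matrix:
    """Concatenate each pooled vector with its selected feature, then apply one affine layer."""
    outputs = []
    for p, feature in zip(pooled, feature_index):
        # Assumed encoding: the feature index as a single scalar, so input width = pooled width + 1.
        x = list(p) + [float(feature)]
        if any(len(row) != len(x) for row in weight):
            raise ValueError("weight must have len(pooled[i]) + 1 columns")
        outputs.append([sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)])
    return outputs
```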
- class gflownet.policy.multihead_tree.OperatorSelectionHead(backbone, input_dim, n_layers=2, hidden_dim=64, activation='LeakyReLU')
Bases: torch.nn.Module
A graph-level prediction head that pools the representations from the backbone and passes them through an MLP.
The MLP input is expected to have dimensionality equal to the number of available features plus two, with the last two elements holding the feature and the threshold selected in the previous stages (concatenated with the pooled graph representation).
- Parameters:
backbone (torch.nn.Module)
input_dim (int)
n_layers (int)
hidden_dim (int)
activation (str)
- forward(data, feature_index, threshold)
- Parameters:
data (torch_geometric.data.Data)
feature_index (torch.Tensor)
threshold (torch.Tensor)
- Return type:
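The operator head appends both earlier decisions, the selected feature and the threshold, to the pooled representation. The plain-Python sketch below assumes raw scalar encodings for both and, for illustration, two candidate operators (so `weight` has two rows). Since the signature has no `output_dim` argument, the number of operators is presumably fixed inside the module; the function name is hypothetical.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def operator_head_forward(
    pooled: Matrix,
    feature_index: Sequence[int],
    threshold: Sequence[float],
    weight: Matrix,
    bias: Sequence[float],
) -> Matrix:
    """Append (feature, threshold) to each pooled vector and return one logit per operator."""
    logits = []
    for p, feature, thr in zip(pooled, feature_index, threshold):
        # Assumed encoding: two raw scalars, so input width = pooled width + 2.
        x = list(p) + [float(feature), float(thr)]
        if any(len(row) != len(x) for row in weight):
            raise ValueError("weight must have len(pooled[i]) + 2 columns")
        logits.append([sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)])
    return logits
```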
- class gflownet.policy.multihead_tree.ForwardTreeModel(continuous, n_features, policy_output_dim, leaf_index, feature_index, threshold_index, operator_index, eos_index, base=None, backbone_args=None, leaf_head_args=None, feature_head_args=None, threshold_head_args=None, operator_head_args=None)
Bases: torch.nn.Module
A model that combines the backbone and several output heads, which will be used depending on the current stage of the passed state.
- Parameters:
continuous (bool)
n_features (int)
policy_output_dim (int)
leaf_index (int)
feature_index (int)
threshold_index (int)
operator_index (int)
eos_index (int)
base (Optional[MultiheadTreePolicy])
backbone_args (Optional[dict])
leaf_head_args (Optional[dict])
feature_head_args (Optional[dict])
threshold_head_args (Optional[dict])
operator_head_args (Optional[dict])
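Selecting a head by the stage of the state can be sketched as a dispatch table. The sketch assumes, hypothetically, that each `*_index` argument marks where that head's logits start in a policy output vector of length `policy_output_dim`, and it fills the remaining entries with a constant on the assumption that invalid actions are masked elsewhere. The `Stage` enum and `forward_policy_output` are illustrative, not this module's API.

```python
from enum import Enum
from typing import Callable, Dict, List, Sequence


class Stage(Enum):
    """Hypothetical stage labels: which decision the next forward action makes."""

    LEAF = "leaf"
    FEATURE = "feature"
    THRESHOLD = "threshold"
    OPERATOR = "operator"


Head = Callable[[object], Sequence[float]]


def forward_policy_output(
    stage: Stage,
    state: object,
    heads: Dict[Stage, Head],
    offsets: Dict[Stage, int],
    policy_output_dim: int,
    fill: float = 0.0,
) -> List[float]:
    """Run only the head for the current stage and place its logits at that stage's offset."""
    out = [fill] * policy_output_dim  # entries of other stages; assumed to be masked later
    logits = list(heads[stage](state))
    start = offsets[stage]
    if start + len(logits) > policy_output_dim:
        raise ValueError("head output does not fit in the policy output vector")
    out[start:start + len(logits)] = logits
    return out
```

Running only the active head for each state avoids computing logits that would be masked anyway; a batched version would group states by stage before calling each head.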
- class gflownet.policy.multihead_tree.BackwardTreeModel(continuous, n_features, policy_output_dim, leaf_index, feature_index, threshold_index, operator_index, eos_index, base=None, backbone_args=None, leaf_head_args=None)
Bases: torch.nn.Module
A model that combines the backbone and several output heads, which will be used depending on the current stage of the passed state.
In contrast to ForwardTreeModel, it has fewer output heads, as some of the backward transitions are deterministic.
- Parameters:
continuous (bool)
n_features (int)
policy_output_dim (int)
leaf_index (int)
feature_index (int)
threshold_index (int)
operator_index (int)
eos_index (int)
base (Optional[MultiheadTreePolicy])
backbone_args (Optional[dict])
leaf_head_args (Optional[dict])
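Why deterministic backward transitions need no head can be seen from a masked softmax: when the mask leaves a single valid action, that action receives probability 1 whatever the logits are. The plain-Python sketch below pairs such a softmax with a hypothetical dispatcher that uses a learned head only for the stages listed in `learned_heads` (mirroring the signature, which only accepts leaf_head_args) and returns constant logits otherwise. The function names and stage keys are illustrative.

```python
import math
from typing import Callable, Dict, Hashable, List, Sequence


def masked_softmax(logits: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Softmax over entries where mask is True; masked entries get probability 0."""
    valid = [x for x, ok in zip(logits, mask) if ok]
    if not valid:
        raise ValueError("at least one action must be valid")
    top = max(valid)  # subtract the max for numerical stability
    exps = [math.exp(x - top) if ok else 0.0 for x, ok in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]


def backward_logits(
    stage: Hashable,
    state: object,
    learned_heads: Dict[Hashable, Callable[[object], Sequence[float]]],
    policy_output_dim: int,
) -> List[float]:
    """Hypothetical dispatcher: learned head where the choice is ambiguous, constants elsewhere."""
    if stage in learned_heads:
        return list(learned_heads[stage](state))
    # Deterministic stage: the mask leaves one valid action, so any constant works.
    return [0.0] * policy_output_dim
```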
- class gflownet.policy.multihead_tree.MultiheadTreePolicy(config, env, device, float_precision, base=None)
Bases: gflownet.policy.base.Policy
Policy wrapper using ForwardTreeModel and BackwardTreeModel as the policy models.
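Both tree models accept an optional `base` of type MultiheadTreePolicy. One plausible reading, stated here as an assumption rather than documented behavior, is that a backward policy built with the forward policy as `base` reuses that policy's backbone instead of creating its own. The plain-Python stand-ins below (`SketchBackbone`, `SketchTreeModel`, `SketchPolicy`, all hypothetical) show this sharing pattern: because both models hold the same object, updates to the backbone are seen in both directions.

```python
from typing import Optional


class SketchBackbone:
    """Stand-in for a GNN backbone; holds a list of parameters."""

    def __init__(self, hidden_dim: int = 64) -> None:
        self.params = [0.0] * hidden_dim


class SketchTreeModel:
    def __init__(self, base: Optional["SketchPolicy"] = None, hidden_dim: int = 64) -> None:
        # Assumed behavior: reuse the base policy's backbone if given, otherwise build a new one.
        self.backbone = base.model.backbone if base is not None else SketchBackbone(hidden_dim)


class SketchPolicy:
    def __init__(self, base: Optional["SketchPolicy"] = None) -> None:
        self.model = SketchTreeModel(base=base)
```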