gflownet.policy.multihead_tree
==============================

.. py:module:: gflownet.policy.multihead_tree


Classes
-------

.. autoapisummary::

   gflownet.policy.multihead_tree.Backbone
   gflownet.policy.multihead_tree.LeafSelectionHead
   gflownet.policy.multihead_tree.FeatureSelectionHead
   gflownet.policy.multihead_tree.ThresholdSelectionHead
   gflownet.policy.multihead_tree.OperatorSelectionHead
   gflownet.policy.multihead_tree.ForwardTreeModel
   gflownet.policy.multihead_tree.BackwardTreeModel
   gflownet.policy.multihead_tree.MultiheadTreePolicy


Module Contents
---------------

.. py:class:: Backbone(input_dim, n_layers = 3, hidden_dim = 64, layer = 'GCNConv', activation = 'LeakyReLU')

   Bases: :py:obj:`torch.nn.Module`

   GNN backbone: a stack of GNN layers that can be used for processing graphs.

   .. py:attribute:: hidden_dim
      :value: 64

   .. py:attribute:: model

   .. py:method:: forward(data)


.. py:class:: LeafSelectionHead(backbone, max_nodes, model_eos = True, n_layers = 2, hidden_dim = 64, layer = 'GCNConv', activation = 'LeakyReLU')

   Bases: :py:obj:`torch.nn.Module`

   Node-level prediction head. It consists of a stack of GNN layers and, if
   ``model_eos`` is True, a separate linear layer for modeling the exit action.

   Note that the forward function converts the node-level predictions into the
   expected vector policy output. As a result, the output is a regular tensor
   with the logits at the correct positions, regardless of the graph shape.

   .. py:attribute:: max_nodes

   .. py:attribute:: model_eos
      :value: True

   .. py:attribute:: backbone

   .. py:attribute:: body

   .. py:attribute:: leaf_head_layer

   .. py:method:: forward(data)


.. py:class:: FeatureSelectionHead(backbone, input_dim, output_dim, n_layers = 2, hidden_dim = 64, activation = 'LeakyReLU')

   Bases: :py:obj:`torch.nn.Module`

   A graph-level prediction head that pools the representations from the
   backbone and passes them through an MLP. Its output dimensionality is
   expected to equal the number of available features.

   .. py:attribute:: backbone

   .. py:attribute:: model

   .. py:method:: forward(data)


.. py:class:: ThresholdSelectionHead(backbone, input_dim, output_dim, n_layers = 2, hidden_dim = 64, activation = 'LeakyReLU')

   Bases: :py:obj:`torch.nn.Module`

   A graph-level prediction head that pools the representations from the
   backbone and passes them through an MLP. It is expected to have
   dimensionality equal to the number of available features plus one, with the
   last element being the feature that was selected in the previous stage
   (which is concatenated with the pooled graph representation).

   .. py:attribute:: backbone

   .. py:attribute:: model

   .. py:method:: forward(data, feature_index)


.. py:class:: OperatorSelectionHead(backbone, input_dim, n_layers = 2, hidden_dim = 64, activation = 'LeakyReLU')

   Bases: :py:obj:`torch.nn.Module`

   A graph-level prediction head that pools the representations from the
   backbone and passes them through an MLP. It is expected to have
   dimensionality equal to the number of available features plus two, with the
   last two elements being the feature and the threshold that were selected in
   the previous stages (which are concatenated with the pooled graph
   representation).

   .. py:attribute:: backbone

   .. py:attribute:: model

   .. py:method:: forward(data, feature_index, threshold)


.. py:class:: ForwardTreeModel(continuous, n_features, policy_output_dim, leaf_index, feature_index, threshold_index, operator_index, eos_index, base = None, backbone_args = None, leaf_head_args = None, feature_head_args = None, threshold_head_args = None, operator_head_args = None)

   Bases: :py:obj:`torch.nn.Module`

   A model that combines the backbone and several output heads, which are used
   depending on the current stage of the passed state.

   .. py:attribute:: continuous

   .. py:attribute:: n_features

   .. py:attribute:: policy_output_dim

   .. py:attribute:: leaf_index

   .. py:attribute:: feature_index

   .. py:attribute:: threshold_index

   .. py:attribute:: operator_index

   .. py:attribute:: eos_index

   .. py:attribute:: leaf_head

   .. py:attribute:: feature_head

   .. py:attribute:: threshold_head

   .. py:attribute:: operator_head

   .. py:method:: forward(x)


.. py:class:: BackwardTreeModel(continuous, n_features, policy_output_dim, leaf_index, feature_index, threshold_index, operator_index, eos_index, base = None, backbone_args = None, leaf_head_args = None)

   Bases: :py:obj:`torch.nn.Module`

   A model that combines the backbone and several output heads, which are used
   depending on the current stage of the passed state. In contrast to
   ``ForwardTreeModel``, it has fewer output heads, as some of the backward
   transitions are deterministic.

   .. py:attribute:: continuous

   .. py:attribute:: n_features

   .. py:attribute:: policy_output_dim

   .. py:attribute:: leaf_index

   .. py:attribute:: feature_index

   .. py:attribute:: threshold_index

   .. py:attribute:: operator_index

   .. py:attribute:: eos_index

   .. py:attribute:: complete_stage_head

   .. py:attribute:: leaf_stage_head

   .. py:method:: forward(x)


.. py:class:: MultiheadTreePolicy(config, env, device, float_precision, base=None)

   Bases: :py:obj:`gflownet.policy.base.Policy`

   Policy wrapper that uses ``ForwardTreeModel`` and ``BackwardTreeModel`` as
   the policy models.

   .. py:attribute:: backbone_args

   .. py:attribute:: leaf_head_args

   .. py:attribute:: feature_head_args

   .. py:attribute:: operator_head_args

   .. py:attribute:: continuous

   .. py:attribute:: n_features

   .. py:attribute:: policy_output_dim

   .. py:attribute:: leaf_index

   .. py:attribute:: feature_index

   .. py:attribute:: threshold_index

   .. py:attribute:: operator_index

   .. py:attribute:: eos_index

   .. py:attribute:: is_model
      :value: True

   .. py:method:: parse_config(config)

   .. py:method:: instantiate()

   .. py:method:: __call__(states)
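
The node-to-vector conversion described for ``LeafSelectionHead`` can be
illustrated with a minimal pure-Python sketch. It assumes that every node
carries a fixed integer position in the output vector (for example, its index
in an array layout of the tree) and that positions of absent nodes receive a
constant fill value; masking of invalid actions is assumed to happen elsewhere.
The names ``scatter_node_logits`` and ``batch_scatter`` and the ``fill``
default are hypothetical and are not part of ``gflownet.policy.multihead_tree``.

```python
from typing import List, Optional, Sequence, Tuple

# Hypothetical helpers illustrating the idea; not the library's implementation.


def scatter_node_logits(
    node_logits: Sequence[float],
    node_positions: Sequence[int],
    max_nodes: int,
    eos_logit: Optional[float] = None,
    fill: float = 0.0,
) -> List[float]:
    """Place the per-node logits of one graph into a fixed-size vector.

    Entry ``i`` holds the logit of the node whose (assumed) position is ``i``;
    positions of nodes missing from this graph keep ``fill``. If ``eos_logit``
    is given it is appended last, mimicking a separate exit-action output.
    """
    if len(node_logits) != len(node_positions):
        raise ValueError("need exactly one position per node logit")
    out = [fill] * max_nodes
    for logit, pos in zip(node_logits, node_positions):
        if not 0 <= pos < max_nodes:
            raise ValueError(f"node position {pos} outside [0, {max_nodes})")
        out[pos] = logit
    if eos_logit is not None:
        out.append(eos_logit)
    return out


def batch_scatter(
    graphs: Sequence[Tuple[Sequence[float], Sequence[int], Optional[float]]],
    max_nodes: int,
    fill: float = 0.0,
) -> List[List[float]]:
    """Scatter every graph of a batch; all rows end up with the same length."""
    return [
        scatter_node_logits(logits, positions, max_nodes, eos, fill)
        for logits, positions, eos in graphs
    ]


# Two trees with different shapes (a dense one and a sparse one).
rows = batch_scatter(
    [
        ([0.5, -1.0, 2.0], [0, 1, 2], 0.1),
        ([1.5, -0.2, 0.9], [0, 2, 6], 0.3),
    ],
    max_nodes=7,
)
```

Because every row has length ``max_nodes`` (plus one when an exit logit is
appended), trees of different shapes can be stacked into one regular batch,
which is the property the class description relies on.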
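
The stage-dependent use of the graph-level heads in ``ForwardTreeModel``, and
the concatenation described for ``ThresholdSelectionHead`` and
``OperatorSelectionHead``, can be sketched in the same spirit. This sketch
assumes that the current stage is already known, that each head's logits are
written into a flat vector of length ``policy_output_dim`` starting at a
per-stage offset (loosely mirroring the ``feature_index``,
``threshold_index`` and ``operator_index`` attributes, whose exact meaning is
not documented here), and it replaces the MLPs with toy callables. The stage
labels and the names ``head_input``, ``place_head_output`` and
``forward_sketch`` are hypothetical; the node-level leaf stage is left out.

```python
from typing import Callable, Dict, List, Optional, Sequence

# Hypothetical stage labels; the actual environment may encode stages differently.
FEATURE, THRESHOLD, OPERATOR = "feature", "threshold", "operator"

Head = Callable[[List[float]], List[float]]


def head_input(
    pooled: Sequence[float],
    stage: str,
    feature: Optional[int] = None,
    threshold: Optional[float] = None,
) -> List[float]:
    """Concatenate the pooled graph representation with earlier choices."""
    if stage == FEATURE:
        return list(pooled)
    if stage == THRESHOLD:
        # Threshold head input: pooled representation + selected feature.
        return list(pooled) + [float(feature)]
    if stage == OPERATOR:
        # Operator head input: pooled representation + feature and threshold.
        return list(pooled) + [float(feature), float(threshold)]
    raise ValueError(f"unknown stage: {stage!r}")


def place_head_output(
    head_logits: Sequence[float], start: int, policy_output_dim: int, fill: float = 0.0
) -> List[float]:
    """Write one head's logits into a flat policy output beginning at ``start``."""
    if start < 0 or start + len(head_logits) > policy_output_dim:
        raise ValueError("head output does not fit into the policy output")
    out = [fill] * policy_output_dim
    out[start:start + len(head_logits)] = list(head_logits)
    return out


def forward_sketch(
    stage: str,
    pooled: Sequence[float],
    heads: Dict[str, Head],
    offsets: Dict[str, int],
    policy_output_dim: int,
    feature: Optional[int] = None,
    threshold: Optional[float] = None,
) -> List[float]:
    """Select the head for ``stage``, run it, and place its logits."""
    x = head_input(pooled, stage, feature, threshold)
    return place_head_output(heads[stage](x), offsets[stage], policy_output_dim)


# Toy stand-ins for the MLP heads (two features, one threshold value, two operators).
toy_heads: Dict[str, Head] = {
    FEATURE: lambda x: [sum(x), -sum(x)],
    THRESHOLD: lambda x: [x[-1]],        # echoes the selected feature index
    OPERATOR: lambda x: [x[-2], x[-1]],  # echoes feature index and threshold
}
toy_offsets = {FEATURE: 0, THRESHOLD: 2, OPERATOR: 3}

out = forward_sketch(
    OPERATOR, [0.5, 0.5], toy_heads, toy_offsets, 5, feature=1, threshold=0.25
)
```

Writing each head into its own slice of one fixed-length vector keeps the
output shape identical across stages, so a single policy output can be
returned whichever head was used.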