gflownet.policy.simple_tree
===========================

.. py:module:: gflownet.policy.simple_tree


Classes
-------

.. autoapisummary::

   gflownet.policy.simple_tree.Backbone
   gflownet.policy.simple_tree.Head
   gflownet.policy.simple_tree.SimpleTreeModel
   gflownet.policy.simple_tree.SimpleTreePolicy


Module Contents
---------------

.. py:class:: Backbone(input_dim, n_layers = 3, hidden_dim = 128, layer = 'GINEConv', activation = 'LeakyReLU')

   Bases: :py:obj:`torch.nn.Module`


   GNN backbone: a stack of GNN layers for processing graphs.


   .. py:attribute:: hidden_dim
      :value: 128


   .. py:attribute:: model


   .. py:method:: forward(data)


.. py:class:: Head(backbone, out_dim, n_layers = 2, hidden_dim = 256, activation = 'LeakyReLU')

   Bases: :py:obj:`torch.nn.Module`


   GNN head: a pooling layer followed by a stack of linear layers. It takes the
   input graphs processed by the provided ``backbone`` and outputs ``out_dim``
   values.

   This is a naive variant of a Tree policy model: rather than making node-level
   predictions, it predicts the whole policy output at once.


   .. py:attribute:: backbone


   .. py:attribute:: head


   .. py:method:: forward(data)


.. py:class:: SimpleTreeModel(n_features, policy_output_dim, base = None, backbone_args = None, head_args = None)

   Bases: :py:obj:`torch.nn.Module`


   Combination of the backbone and head models. In ``forward``, it converts the
   states into a batch of graphs and passes them through both models.


   .. py:attribute:: n_features


   .. py:attribute:: policy_output_dim


   .. py:attribute:: model


   .. py:method:: forward(x)


.. py:class:: SimpleTreePolicy(config, env, device, float_precision, base=None)

   Bases: :py:obj:`gflownet.policy.base.Policy`


   Policy wrapper that uses :py:class:`SimpleTreeModel` as the policy model.


   .. py:attribute:: backbone_args


   .. py:attribute:: head_args


   .. py:attribute:: n_features


   .. py:attribute:: policy_output_dim


   .. py:attribute:: is_model
      :value: True


   .. py:method:: parse_config(config)


   .. py:method:: instantiate()


   .. py:method:: __call__(states)
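The backbone/head split documented above can be illustrated with a minimal, dependency-free sketch of the data flow. It is written in plain Python so it runs without ``torch`` and is not the library's implementation. ``ToyBackbone``, ``ToyHead``, ``linear``, ``random_layer`` and ``leaky_relu`` are hypothetical names. The sketch assumes sum aggregation over neighbours in the spirit of GIN with epsilon = 0. The real ``Backbone`` defaults to a ``'GINEConv'`` layer, which also consumes edge features, and those are omitted here. It also assumes mean pooling, since the documentation only says "a pooling layer", and it uses random weights instead of trained ones.

```python
import random
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def leaky_relu(x: float, slope: float = 0.01) -> float:
    return x if x > 0 else slope * x


def linear(weights: Matrix, bias: Vector, x: Vector) -> Vector:
    # weights has shape (out_dim, in_dim); returns W @ x + b
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def random_layer(in_dim: int, out_dim: int, rng: random.Random) -> Tuple[Matrix, Vector]:
    # Hypothetical initialisation: small uniform weights instead of trained parameters
    weights = [[rng.uniform(-0.5, 0.5) for _ in range(in_dim)] for _ in range(out_dim)]
    bias = [rng.uniform(-0.1, 0.1) for _ in range(out_dim)]
    return weights, bias


class ToyBackbone:
    """Hypothetical stand-in for Backbone: a stack of message-passing layers."""

    def __init__(self, input_dim: int, n_layers: int = 3, hidden_dim: int = 128, seed: int = 0):
        rng = random.Random(seed)
        self.hidden_dim = hidden_dim
        dims = [input_dim] + [hidden_dim] * n_layers
        self.layers = [random_layer(a, b, rng) for a, b in zip(dims, dims[1:])]

    def forward(self, node_features: List[Vector], edges: List[Tuple[int, int]]) -> List[Vector]:
        h = node_features
        for weights, bias in self.layers:
            # Assumed GIN-style aggregation (eps = 0): own features + sum of incoming neighbours
            agg = [list(v) for v in h]
            for src, dst in edges:
                agg[dst] = [a + b for a, b in zip(agg[dst], h[src])]
            h = [[leaky_relu(x) for x in linear(weights, bias, v)] for v in agg]
        return h  # one embedding of size hidden_dim per node


class ToyHead:
    """Hypothetical stand-in for Head: pooling followed by linear layers."""

    def __init__(self, backbone: ToyBackbone, out_dim: int, n_layers: int = 2,
                 hidden_dim: int = 256, seed: int = 1):
        rng = random.Random(seed)
        self.backbone = backbone
        dims = [backbone.hidden_dim] + [hidden_dim] * (n_layers - 1) + [out_dim]
        self.layers = [random_layer(a, b, rng) for a, b in zip(dims, dims[1:])]

    def forward(self, node_features: List[Vector], edges: List[Tuple[int, int]]) -> Vector:
        h = self.backbone.forward(node_features, edges)
        # Assumed mean pooling: one vector per graph, so no node-level predictions are made
        out = [sum(col) / len(h) for col in zip(*h)]
        for i, (weights, bias) in enumerate(self.layers):
            out = linear(weights, bias, out)
            if i < len(self.layers) - 1:  # no activation on the output layer
                out = [leaky_relu(x) for x in out]
        return out  # out_dim values for the whole graph


if __name__ == "__main__":
    # A 3-node path tree 0 - 1 - 2 with 4 features per node; edges listed in both directions
    x = [[1.0, 0.0, 0.0, 0.5], [0.0, 1.0, 0.0, 0.5], [0.0, 0.0, 1.0, 0.5]]
    edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
    head = ToyHead(ToyBackbone(input_dim=4, hidden_dim=16), out_dim=5, hidden_dim=32)
    print(len(head.forward(x, edges)))  # 5
```

Because the head pools all node embeddings into a single vector before its linear layers, the output always has ``out_dim`` entries regardless of how many nodes the tree has. This is what allows such a naive variant to emit the whole policy output at once instead of one prediction per node.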