gflownet.policy.simple_tree

Classes

Backbone

GNN backbone: a stack of GNN layers that can be used for processing graphs.

Head

GNN head: a combination of a pooling layer and a stack of linear layers that takes the input graphs processed by the provided backbone and outputs out_dim values.

SimpleTreeModel

Combination of backbone and head models. In forward, it converts the states to a batch of graphs, and passes them through both models.

SimpleTreePolicy

Policy wrapper using SimpleTreeModel as the policy model.

Module Contents

class gflownet.policy.simple_tree.Backbone(input_dim, n_layers=3, hidden_dim=128, layer='GINEConv', activation='LeakyReLU')

Bases: torch.nn.Module

GNN backbone: a stack of GNN layers that can be used for processing graphs.

Parameters:
  • input_dim (int)

  • n_layers (int)

  • hidden_dim (int)

  • layer (str)

  • activation (str)

hidden_dim = 128
model
forward(data)
Parameters:

data (torch_geometric.data.Data)

Return type:

torch.Tensor
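The constructor arguments above suggest how the layer stack may be sized. The following is a minimal, dependency-free sketch, not the library's implementation: it assumes the first layer maps input_dim to hidden_dim and every later layer maps hidden_dim to hidden_dim. The helper name layer_dims is hypothetical and does not exist in gflownet.

```python
from typing import List, Tuple


def layer_dims(
    input_dim: int, n_layers: int = 3, hidden_dim: int = 128
) -> List[Tuple[int, int]]:
    """Return (in_dim, out_dim) for each layer of a hypothetical stack.

    Assumption: only the first layer sees input_dim; all other layers
    operate on hidden_dim features. Defaults mirror the Backbone signature.
    """
    if n_layers < 1:
        raise ValueError("n_layers must be at least 1")
    dims = []
    in_dim = input_dim
    for _ in range(n_layers):
        dims.append((in_dim, hidden_dim))
        in_dim = hidden_dim  # later layers consume the previous output
    return dims


dims = layer_dims(input_dim=5)
print(dims)  # [(5, 128), (128, 128), (128, 128)]
```

Under this assumption, the tensor returned by forward would carry hidden_dim features per node, whatever the input feature size.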

class gflownet.policy.simple_tree.Head(backbone, out_dim, n_layers=2, hidden_dim=256, activation='LeakyReLU')

Bases: torch.nn.Module

GNN head: a combination of a pooling layer and a stack of linear layers that takes the input graphs processed by the provided backbone and outputs out_dim values.

This is a naive variant of a Tree policy model: it does not do any node-level prediction, and instead predicts the whole policy output at once.

Parameters:
  • backbone (torch.nn.Module)

  • out_dim (int)

  • n_layers (int)

  • hidden_dim (int)

  • activation (str)

backbone
head
forward(data)
Parameters:

data (torch_geometric.data.Data)

Return type:

(torch.Tensor, torch.Tensor)
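To illustrate what a pooling layer does before the linear layers, here is a small sketch in plain Python. It assumes mean pooling, which this page does not confirm; the actual pooling used by Head may differ. The function mean_pool is hypothetical. The batch argument follows the PyTorch Geometric convention of a vector that assigns each node to a graph index.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def mean_pool(
    node_embeddings: Sequence[Sequence[float]], batch: Sequence[int]
) -> Dict[int, List[float]]:
    """Average node embeddings per graph (assumed pooling choice)."""
    if len(node_embeddings) != len(batch):
        raise ValueError("one graph index is required per node")
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = defaultdict(int)
    for emb, graph_idx in zip(node_embeddings, batch):
        if graph_idx not in sums:
            sums[graph_idx] = [0.0] * len(emb)
        # Accumulate feature-wise sums for the graph this node belongs to.
        sums[graph_idx] = [s + v for s, v in zip(sums[graph_idx], emb)]
        counts[graph_idx] += 1
    return {g: [s / counts[g] for s in total] for g, total in sums.items()}


# Three nodes: the first two belong to graph 0, the last one to graph 1.
nodes = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
batch = [0, 0, 1]
pooled = mean_pool(nodes, batch)
print(pooled)  # {0: [2.0, 3.0], 1: [10.0, 20.0]}
```

Pooling yields one fixed-size vector per graph, which is what allows a stack of linear layers to predict all out_dim values at once rather than per node.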

class gflownet.policy.simple_tree.SimpleTreeModel(n_features, policy_output_dim, base=None, backbone_args=None, head_args=None)

Bases: torch.nn.Module

Combination of backbone and head models. In forward, it converts the states to a batch of graphs, and passes them through both models.

Parameters:
  • n_features (int)

  • policy_output_dim (int)

  • base (Optional[SimpleTreePolicy])

  • backbone_args (Optional[dict])

  • head_args (Optional[dict])

n_features
policy_output_dim
model
forward(x)
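Since backbone_args and head_args are optional dictionaries, one plausible way to combine them with the defaults listed in the Backbone and Head signatures is sketched below. This is an assumption for illustration only: how SimpleTreeModel actually handles None or partial dictionaries is not documented on this page, and resolve_args, BACKBONE_DEFAULTS, and HEAD_DEFAULTS are hypothetical names.

```python
from typing import Optional

# Default values copied from the Backbone and Head signatures on this page.
BACKBONE_DEFAULTS = {
    "n_layers": 3,
    "hidden_dim": 128,
    "layer": "GINEConv",
    "activation": "LeakyReLU",
}
HEAD_DEFAULTS = {"n_layers": 2, "hidden_dim": 256, "activation": "LeakyReLU"}


def resolve_args(defaults: dict, overrides: Optional[dict] = None) -> dict:
    """Merge user overrides into defaults, rejecting unknown keys."""
    overrides = overrides or {}
    unknown = set(overrides) - set(defaults)
    if unknown:
        raise KeyError(f"unknown arguments: {sorted(unknown)}")
    # Later entries win, so overrides replace the matching defaults.
    return {**defaults, **overrides}


backbone_kwargs = resolve_args(BACKBONE_DEFAULTS, {"n_layers": 4})
head_kwargs = resolve_args(HEAD_DEFAULTS, None)
print(backbone_kwargs["n_layers"], head_kwargs["hidden_dim"])  # 4 256
```

Returning a new dictionary keeps the defaults untouched, so repeated model construction does not leak overrides from one instance to the next.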
class gflownet.policy.simple_tree.SimpleTreePolicy(config, env, device, float_precision, base=None)

Bases: gflownet.policy.base.Policy

Policy wrapper using SimpleTreeModel as the policy model.

backbone_args
head_args
n_features
policy_output_dim
is_model = True
parse_config(config)
instantiate()
__call__(states)