gflownet.evaluator.base

Base evaluator class for a GFlowNetAgent.

In charge of evaluating a generic GFlowNetAgent, computing metrics, plotting figures and optionally logging results using the GFlowNetAgent’s Logger.

Use this BaseEvaluator as an example when implementing your own evaluator class for a custom use case.

Important

Prefer the from_dir() and from_agent() class methods to instantiate an evaluator.

See Using an Evaluator for more details about how to use an Evaluator.

Classes

BaseEvaluator

Base evaluator class for GFlowNetAgent.

Module Contents

class gflownet.evaluator.base.BaseEvaluator(gfn_agent=None, **config)

Bases: gflownet.evaluator.abstract.AbstractEvaluator

Base evaluator class for GFlowNetAgent.

In particular, implements the eval() method with:

And the plot() method with:

  • The GFlowNetEnv’s plot_reward_samples() method.

  • The GFlowNetEnv’s plot_kde() method if it exists, for both the kde_pred and kde_true arguments if they are returned in the "data" dict of the eval() method.

See the AbstractEvaluator for more details about other methods and attributes, including the __init__().

define_new_metrics()

Method to be implemented by subclasses to define new metrics.

Example

def define_new_metrics(self):
    return {
        "my_custom_metric": {
            "display_name": "My custom metric",
            "requirements": ["density", "new_req"],
        }
    }

Returns:

dict – Dictionary of new metrics to add to the global METRICS dict.
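
As a rough illustration of how such a dictionary might be consumed, the sketch below merges new metrics into a registry. The contents of `METRICS` shown here, the `register_metrics` helper and its validation rules are hypothetical stand-ins, not the library's actual code.

```python
# Hypothetical stand-in for the global METRICS dict; the real one is defined
# inside the gflownet package and has different contents.
METRICS = {
    "l1": {"display_name": "L1 error", "requirements": ["density"]},
}


def define_new_metrics():
    # Same shape as the example shown in the documentation above.
    return {
        "my_custom_metric": {
            "display_name": "My custom metric",
            "requirements": ["density", "new_req"],
        }
    }


def register_metrics(registry, new_metrics):
    """Return a copy of `registry` extended with `new_metrics` (hypothetical helper)."""
    clashes = set(registry) & set(new_metrics)
    if clashes:
        raise ValueError(f"Metric names already defined: {sorted(clashes)}")
    for name, spec in new_metrics.items():
        missing = {"display_name", "requirements"} - set(spec)
        if missing:
            raise ValueError(f"Metric {name!r} is missing keys: {sorted(missing)}")
    # Build a new dict so the original registry is left untouched.
    return {**registry, **new_metrics}


all_metrics = register_metrics(METRICS, define_new_metrics())
```

Returning a new dictionary rather than mutating the registry in place keeps the sketch free of side effects; whether the library does the same is not stated here.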

eval_top_k(it, gfn_states=None, random_states=None)

Sample from the current GFN and compute metrics and plots for the top k states according to both the energy and the reward.

Parameters:
  • it (int) – current iteration

  • gfn_states (list, optional) – Already sampled gfn states. Defaults to None.

  • random_states (list, optional) – Already sampled random states. Defaults to None.

Returns:

dict – Computed dict of metrics and figures and, optionally (only once), summary metrics. Schema: {"metrics": {str: float}, "figs": {str: plt.Figure}, "summary": {str: float}}.
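
One way to picture the top-k step is the sketch below, which ranks already sampled states by a score and averages the best k. The helper name `top_k_summary`, the toy states and scores, and the convention that lower energy is better are assumptions made for illustration; the actual eval_top_k() implementation may differ.

```python
import heapq
from statistics import fmean


def top_k_summary(states, scores, k, largest=True):
    """Return the k best (state, score) pairs and the mean of their scores."""
    pairs = list(zip(states, scores))
    pick = heapq.nlargest if largest else heapq.nsmallest
    best = pick(k, pairs, key=lambda pair: pair[1])
    return best, fmean(score for _, score in best)


# Toy data standing in for sampled states and their scores.
states = ["a", "b", "c", "d"]
rewards = [0.1, 0.9, 0.5, 0.7]
energies = [3.0, 0.2, 1.0, 0.5]

# Top-2 by reward (higher is better).
best_by_reward, mean_reward = top_k_summary(states, rewards, k=2)
# Top-2 by energy (assumed here: lower is better).
best_by_energy, mean_energy = top_k_summary(states, energies, k=2, largest=False)
```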

compute_log_prob_metrics(x_tt, metrics=None)

Compute log-probability metrics for the given test data.

Uses estimate_logprobs_data().

Known metrics:

  • mean_logprobs_std: Mean of the standard deviation of the log-probabilities.

  • mean_probs_std: Mean of the standard deviation of the probabilities.

  • corr_probs_rewards: Correlation between the probabilities and the rewards.

  • corr_logprobs_logrewards: Correlation between the log-probabilities and the log-rewards.

  • var_logrewards_logp: Variance of the log-rewards minus the log-probabilities.

  • nll_tt: Negative log-likelihood of the test data.

  • logprobs_std_nll_ratio: Ratio of the mean of the standard deviation of the log-probabilities over the negative log-likelihood of the test data.

Returned data in the "data" sub-dict:

  • probs: Probabilities of the test data.

  • rewards: Rewards for the test data.

  • logprobs: Log-probabilities of the test data.

  • logrewards: Log-rewards of the test data.

Parameters:
  • x_tt (torch.Tensor) – Test data.

  • metrics (List[str], optional) – List of metrics to compute. Defaults to None, in which case the evaluator’s self.metrics is used.

Returns:

dict – Computed dict of metrics and data as {"metrics": {str: float}, "data": {str: object}}.
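
The metric descriptions above can be read as the following standard-library sketch. It assumes each test point comes with several repeated log-probability estimates, that the per-point log-probability is the plain mean of those estimates, and that population (not sample) statistics are used; these choices, the `pearson` helper and the toy numbers are illustrative assumptions, not the behaviour of estimate_logprobs_data().

```python
import math
from statistics import fmean, pstdev, pvariance


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    mx, my = fmean(xs), fmean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    norm = math.sqrt(sum((x - mx) ** 2 for x in xs) * sum((y - my) ** 2 for y in ys))
    return cov / norm


# Toy inputs: 3 test points, each with 3 repeated log-probability estimates.
logprobs_estimates = [
    [-1.0, -1.2, -0.8],
    [-2.0, -2.5, -1.5],
    [-3.0, -2.9, -3.1],
]
rewards = [2.0, 1.0, 0.5]

# Assumption: one log-probability per test point, taken as the mean estimate.
logprobs = [fmean(row) for row in logprobs_estimates]
probs = [math.exp(lp) for lp in logprobs]
logrewards = [math.log(r) for r in rewards]

log_prob_metrics = {
    # Spread of the repeated estimates, averaged over test points.
    "mean_logprobs_std": fmean(pstdev(row) for row in logprobs_estimates),
    "mean_probs_std": fmean(
        pstdev([math.exp(lp) for lp in row]) for row in logprobs_estimates
    ),
    "corr_probs_rewards": pearson(probs, rewards),
    "corr_logprobs_logrewards": pearson(logprobs, logrewards),
    "var_logrewards_logp": pvariance(
        [lr - lp for lr, lp in zip(logrewards, logprobs)]
    ),
    # Negative log-likelihood as the negated mean log-probability.
    "nll_tt": -fmean(logprobs),
}
log_prob_metrics["logprobs_std_nll_ratio"] = (
    log_prob_metrics["mean_logprobs_std"] / log_prob_metrics["nll_tt"]
)
```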

compute_density_metrics(x_tt, dict_tt, metrics=None)

Compute density metrics for the given test data.

Known metrics:

  • l1: L1 error between the true and predicted densities.

  • kl: KL divergence between the true and predicted densities.

  • jsd: Jensen-Shannon divergence between the true and predicted densities.

Returned data in the "data" sub-dict:

  • x_sampled: Sampled states from the GFN.

  • kde_pred: KDE policy as per fit_kde().

  • kde_true: True KDE.

Parameters:
  • x_tt (torch.Tensor) – Test data.

  • dict_tt (dict) – Dictionary of test data.

  • metrics (List[str], optional) – List of metrics to compute. Defaults to None, in which case the evaluator’s self.metrics is used.

Returns:

dict – Computed dict of metrics and data as {"metrics": {str: float}, "data": {str: object}}.
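
For intuition, the three density metrics can be sketched on two densities evaluated on a shared grid. Whether the library averages or sums the absolute error for l1, and which direction it uses for kl, is not stated here, so the conventions below (mean absolute error, KL(true || pred)) and the toy densities are assumptions.

```python
import math


def normalize(weights):
    """Scale non-negative weights so that they sum to 1."""
    total = sum(weights)
    return [w / total for w in weights]


def l1_error(p, q):
    """Mean absolute difference between two densities on the same grid."""
    return sum(abs(a - b) for a, b in zip(p, q)) / len(p)


def kl_divergence(p, q):
    """KL(p || q); terms with p == 0 contribute 0. Assumes q > 0 wherever p > 0."""
    return sum(a * math.log(a / b) for a, b in zip(p, q) if a > 0)


def js_divergence(p, q):
    """Jensen-Shannon divergence, the symmetrised KL against the mixture."""
    m = [0.5 * (a + b) for a, b in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


# Toy densities on a 4-cell grid.
density_true = normalize([1.0, 2.0, 4.0, 1.0])
density_pred = normalize([1.0, 3.0, 3.0, 1.0])

density_metrics = {
    "l1": l1_error(density_true, density_pred),
    "kl": kl_divergence(density_true, density_pred),
    "jsd": js_divergence(density_true, density_pred),
}
```

Unlike the KL divergence, the Jensen-Shannon divergence is symmetric in its arguments and bounded above by log 2, which makes it easier to compare across runs.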

eval(metrics=None, **plot_kwargs)

Evaluate the GFlowNetAgent and compute metrics and plots.

If metrics is not provided, the evaluator’s self.metrics attribute is used (default).

Extend in subclasses to add more metrics and plots:

def eval(self, metrics=None, **plot_kwargs):
    result = super().eval(metrics=metrics, **plot_kwargs)
    result["metrics"]["my_custom_metric"] = my_custom_metric_function()
    result["figs"]["My custom plot"] = my_custom_plot_function()
    return result

Parameters:
  • metrics (List[str], optional) – List of metrics to compute, by default the evaluator’s self.metrics attribute.

  • plot_kwargs (dict, optional) – Additional keyword arguments to pass to the plotting methods.

Returns:

dict – Computed dict of metrics and figures as {"metrics": {str: float}, "figs": {str: plt.Figure}}.

plot(x_sampled=None, kde_pred=None, kde_true=None, plot_kwargs={}, **kwargs)

Produce the plots for this evaluator, returned as a dict {str: plt.Figure} that will be logged.

By default, this method will call the following methods of the GFlowNetAgent’s environment if they exist:

  • plot_reward_samples

  • plot_kde (for both the kde_pred and kde_true arguments)

  • plot_samples_topk

Extend this method to add more plots:

def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs):
    # Forward any extra keyword arguments to the base implementation.
    figs = super().plot(x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs)
    figs["My custom plot"] = my_custom_plot_function(x_sampled, kde_pred)
    return figs

Parameters:
  • x_sampled (list, optional) – List of sampled states.

  • kde_pred (sklearn.neighbors.KernelDensity) – KDE policy as per Environment.fit_kde

  • kde_true (object) – True KDE.

  • plot_kwargs (dict) – Additional keyword arguments to pass to the plotting methods.

  • kwargs (dict) – Catch-all for additional arguments.

Returns:

dict[str, plt.Figure] – Dictionary of figures to be logged. The keys are the figure names and the values are the figures.
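
The "call the method if it exists" behaviour described above can be sketched as follows. `ToyEnv`, `collect_figures`, the figure names used as keys and the simplified method signatures are hypothetical; real environments return matplotlib figures and their plotting methods may take different arguments.

```python
class ToyEnv:
    """Hypothetical environment implementing only one of the plotting hooks."""

    def plot_reward_samples(self, x_sampled, **plot_kwargs):
        # A real environment would return a plt.Figure; a string stands in here.
        return f"reward-samples figure for {len(x_sampled)} states"


def collect_figures(env, x_sampled=None, kde_pred=None, kde_true=None, plot_kwargs=None):
    """Call each optional plotting method of `env` only if it is defined."""
    plot_kwargs = plot_kwargs or {}
    figs = {}
    if hasattr(env, "plot_reward_samples") and x_sampled is not None:
        figs["Reward samples"] = env.plot_reward_samples(x_sampled, **plot_kwargs)
    if hasattr(env, "plot_kde"):
        # Plot each KDE that was actually provided.
        if kde_pred is not None:
            figs["Predicted KDE"] = env.plot_kde(kde_pred, **plot_kwargs)
        if kde_true is not None:
            figs["True KDE"] = env.plot_kde(kde_true, **plot_kwargs)
    return figs


figs = collect_figures(ToyEnv(), x_sampled=[0, 1, 2])
no_figs = collect_figures(object(), x_sampled=[0, 1, 2])
```

Guarding each call with `hasattr` lets the same evaluator work with environments that implement only some of the plotting methods.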