gflownet.evaluator.base
Base evaluator class for a GFlowNetAgent.
In charge of evaluating a generic GFlowNetAgent: computing metrics, plotting figures and optionally logging results using the GFlowNetAgent’s Logger.
Take this BaseEvaluator as an example to implement your own evaluator class for your custom use case.
Important
Prefer the from_dir() and from_agent() class methods to instantiate an evaluator.
See Using an Evaluator for more details about how to use an Evaluator.
Classes
BaseEvaluator: Base evaluator class for GFlowNetAgent.
Module Contents
- class gflownet.evaluator.base.BaseEvaluator(gfn_agent=None, **config)[source]
Bases: gflownet.evaluator.abstract.AbstractEvaluator
Base evaluator class for GFlowNetAgent.
In particular, implements the eval() method with:
- compute_log_prob_metrics() to compute log-probability metrics.
- compute_density_metrics() to compute density metrics.
And the plot() method with:
- The GFlowNetEnv’s plot_reward_samples() method.
- The GFlowNetEnv’s plot_kde() method if it exists, for both the kde_pred and kde_true arguments if they are returned in the "data" dict of the eval() method.
See the AbstractEvaluator for more details about other methods and attributes, including the __init__().
- define_new_metrics()[source]
Method to be implemented by subclasses to define new metrics.
Example
    def define_new_metrics(self):
        return {
            "my_custom_metric": {
                "display_name": "My custom metric",
                "requirements": ["density", "new_req"],
            }
        }
- Returns:
dict – Dictionary of new metrics to add to the global METRICS dict.
- eval_top_k(it, gfn_states=None, random_states=None)[source]
Sample from the current GFN and compute metrics and plots for the top k states according to both the energy and the reward.
- Parameters:
it (int) – current iteration
gfn_states (list, optional) – Already sampled gfn states. Defaults to None.
random_states (list, optional) – Already sampled random states. Defaults to None.
- Returns:
dict – Computed dict of metrics, and figures, and optionally (only once) summary metrics. Schema:
{"metrics": {str: float}, "figs": {str: plt.Figure}, "summary": {str: float}}.
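The top-k selection described above can be illustrated with a minimal sketch. It is not the library's implementation: the function name `top_k_summary`, the metric key `mean_reward_top{k}` and the choice of ranking by reward only (higher is better) are assumptions made for illustration.

```python
def top_k_summary(states, rewards, k):
    """Select the k highest-reward states and summarise them (illustrative only).

    `states` and `rewards` are parallel lists; the metric key below is a
    hypothetical name, not one defined by the library.
    """
    # Indices of the k largest rewards, best first.
    order = sorted(range(len(rewards)), key=lambda i: rewards[i], reverse=True)[:k]
    top_states = [states[i] for i in order]
    top_rewards = [rewards[i] for i in order]
    return {
        "metrics": {f"mean_reward_top{k}": sum(top_rewards) / len(top_rewards)},
        "states": top_states,
    }


result = top_k_summary(["a", "b", "c", "d"], [0.1, 0.9, 0.5, 0.7], k=2)
```

The same pattern would apply to a ranking by energy, with the sort direction chosen according to whether lower or higher energy is considered better in the environment at hand.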
- compute_log_prob_metrics(x_tt, metrics=None)[source]
Compute log-probability metrics for the given test data.
Uses estimate_logprobs_data().
Known metrics:
- mean_logprobs_std: Mean of the standard deviation of the log-probabilities.
- mean_probs_std: Mean of the standard deviation of the probabilities.
- corr_probs_rewards: Correlation between the probabilities and the rewards.
- corr_logprobs_logrewards: Correlation between the log-probabilities and the log-rewards.
- var_logrewards_logp: Variance of the log-rewards minus the log-probabilities.
- nll_tt: Negative log-likelihood of the test data.
- logprobs_std_nll_ratio: Ratio of the mean of the standard deviation of the log-probabilities over the negative log-likelihood of the test data.
Returned data in the "data" sub-dict:
- probs: Probabilities of the test data.
- rewards: Rewards for the test data.
- logprobs: Log-probabilities of the test data.
- logrewards: Log-rewards of the test data.
- Parameters:
x_tt (torch.Tensor) – Test data.
metrics (List[str], optional) – List of metrics to compute, by default None, i.e. the evaluator’s self.metrics.
- Returns:
dict – Computed dict of metrics and data as
{"metrics": {str: float}, "data": {str: object}}.
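To make the metric definitions above concrete, here is a minimal pure-Python sketch of how a few of them could be computed. It is an illustration under stated assumptions, not the library's code: it assumes the log-probabilities are estimated several times per test point (hypothetical layout: one list per run), that `nll_tt` is the negative mean log-probability, and that the correlation is Pearson's. The helper names `pearson` and `log_prob_metrics_sketch` are made up for this example.

```python
import math
from statistics import mean, pstdev, pvariance


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences."""
    mx, my = mean(xs), mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


def log_prob_metrics_sketch(logprobs_runs, logrewards):
    """Illustrative metrics from repeated log-probability estimates.

    logprobs_runs: list of runs, each a list of n_data log-probabilities
        (assumed layout, not the library's).
    logrewards: list of n_data log-rewards.
    """
    n_data = len(logrewards)
    # Regroup the estimates so that each inner list belongs to one test point.
    per_point = [[run[i] for run in logprobs_runs] for i in range(n_data)]
    logprobs = [mean(p) for p in per_point]
    diffs = [lr - lp for lr, lp in zip(logrewards, logprobs)]
    return {
        "metrics": {
            # Spread of the estimates per point, averaged over points.
            "mean_logprobs_std": mean(pstdev(p) for p in per_point),
            "corr_logprobs_logrewards": pearson(logprobs, logrewards),
            "var_logrewards_logp": pvariance(diffs),
            # Assumption: NLL taken as the negative mean log-probability.
            "nll_tt": -mean(logprobs),
        },
        "data": {"logprobs": logprobs, "logrewards": logrewards},
    }
```

Note that `var_logrewards_logp` is zero exactly when the log-rewards and log-probabilities differ by a constant, which is the situation where the sampling probabilities are proportional to the rewards.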
- compute_density_metrics(x_tt, dict_tt, metrics=None)[source]
Compute density metrics for the given test data.
Known metrics:
- l1: L1 error between the true and predicted densities.
- kl: KL divergence between the true and predicted densities.
- jsd: Jensen-Shannon divergence between the true and predicted densities.
Returned data in the "data" sub-dict:
- x_sampled: Sampled states from the GFN.
- kde_pred: KDE policy as per fit_kde().
- kde_true: True KDE.
- Parameters:
x_tt (torch.Tensor) – Test data.
dict_tt (dict) – Dictionary of test data.
metrics (List[str], optional) – List of metrics to compute, by default None, i.e. the evaluator’s self.metrics.
- Returns:
dict – Computed dict of metrics and data as
{"metrics": {str: float}, "data": {str: object}}.
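As a rough illustration of the three density metrics, the sketch below compares two discrete densities given as plain lists. It is not the library's implementation: the function name `density_metrics_sketch`, the summed (rather than averaged) L1 error, the KL direction (true relative to predicted) and the `eps` smoothing are all assumptions, and the library may normalise or orient these quantities differently.

```python
import math


def density_metrics_sketch(density_true, density_pred, eps=1e-12):
    """L1, KL and JSD between two discrete densities (illustrative only)."""
    # Normalise both inputs so that each sums to one.
    z_true, z_pred = sum(density_true), sum(density_pred)
    p = [v / z_true for v in density_true]
    q = [v / z_pred for v in density_pred]

    def kl(a, b):
        # Terms with a == 0 contribute nothing; eps guards against log(0).
        return sum(
            x * math.log((x + eps) / (y + eps)) for x, y in zip(a, b) if x > 0
        )

    # Mixture distribution used by the Jensen-Shannon divergence.
    m = [(x + y) / 2 for x, y in zip(p, q)]
    return {
        "l1": sum(abs(x - y) for x, y in zip(p, q)),  # assumed: summed error
        "kl": kl(p, q),  # assumed direction: KL(true || pred)
        "jsd": 0.5 * kl(p, m) + 0.5 * kl(q, m),
    }
```

With natural logarithms, the Jensen-Shannon divergence is bounded above by log 2, a bound reached when the two densities have disjoint support, which makes it easier to compare across runs than the unbounded KL divergence.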
- eval(metrics=None, **plot_kwargs)[source]
Evaluate the GFlowNetAgent and compute metrics and plots.
If metrics is not provided, the evaluator’s self.metrics attribute is used (default).
Extend in subclasses to add more metrics and plots:
    def eval(self, metrics=None, **plot_kwargs):
        result = super().eval(metrics=metrics, **plot_kwargs)
        result["metrics"]["my_custom_metric"] = my_custom_metric_function()
        result["figs"]["My custom plot"] = my_custom_plot_function()
        return result
- Parameters:
metrics (List[str], optional) – List of metrics to compute, by default the evaluator’s self.metrics attribute.
plot_kwargs (dict, optional) – Additional keyword arguments to pass to the plotting methods.
- Returns:
dict – Computed dict of metrics and figures as
{"metrics": {str: float}, "figs": {str: plt.Figure}}.
- plot(x_sampled=None, kde_pred=None, kde_true=None, plot_kwargs={}, **kwargs)[source]
Plots that this evaluator should produce, returned as a dict {str: plt.Figure} which will be logged.
By default, this method will call the following methods of the GFlowNetAgent’s environment if they exist:
- plot_reward_samples
- plot_kde (for both the kde_pred and kde_true arguments)
- plot_samples_topk
Extend this method to add more plots:
    def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs):
        figs = super().plot(x_sampled, kde_pred, kde_true, plot_kwargs)
        figs["My custom plot"] = my_custom_plot_function(x_sampled, kde_pred)
        return figs
- Parameters:
x_sampled (list, optional) – List of sampled states.
kde_pred (sklearn.neighbors.KernelDensity) – KDE policy as per Environment.fit_kde.
kde_true (object) – True KDE.
plot_kwargs (dict) – Additional keyword arguments to pass to the plotting methods.
kwargs (dict) – Catch-all for additional arguments.
- Returns:
dict[str, plt.Figure] – Dictionary of figures to be logged. The keys are the figure names and the values are the figures.