"""
Base evaluator class for a :class:`~gflownet.gflownet.GFlowNetAgent`.
In charge of evaluating a generic :class:`~gflownet.gflownet.GFlowNetAgent`,
computing metrics, plotting figures and optionally logging results using the
:class:`~gflownet.gflownet.GFlowNetAgent`'s :class:`~gflownet.utils.logger.Logger`.
Take this :class:`BaseEvaluator` as example to implement your own evaluator class
for your custom use-case.
.. important::
Prefer the :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.from_dir`
and :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.from_agent`
class methods to instantiate an evaluator.
See :ref:`using an evaluator` for more details about how to use an Evaluator.
"""
import copy
import pickle
import time
from collections import defaultdict
import numpy as np
import torch
from scipy.special import logsumexp
from gflownet.evaluator.abstract import ALL_REQS # noqa
from gflownet.evaluator.abstract import METRICS # noqa
from gflownet.evaluator.abstract import AbstractEvaluator
from gflownet.utils.batch import Batch
from gflownet.utils.common import batch_with_rest, tfloat, torch2np
[docs]
class BaseEvaluator(AbstractEvaluator):
def __init__(self, gfn_agent=None, **config):
    """
    Build the base evaluator for a GFlowNetAgent.

    Every argument is forwarded unchanged to
    :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.__init__`; see the
    :class:`~gflownet.evaluator.abstract.AbstractEvaluator` for details about
    the other methods and attributes.

    This class implements :meth:`eval` on top of:

    - :meth:`compute_log_prob_metrics` to compute log-probability metrics.
    - :meth:`compute_density_metrics` to compute density metrics.

    and :meth:`plot` on top of:

    - The :class:`~gflownet.envs.base.GFlowNetEnv`'s :meth:`plot_reward_samples`
      method.
    - The :class:`~gflownet.envs.base.GFlowNetEnv`'s :meth:`plot_kde` method if
      it exists, for both the ``kde_pred`` and ``kde_true`` arguments if they are
      returned in the ``"data"`` dict of the :meth:`eval` method.

    Parameters
    ----------
    gfn_agent : optional
        Agent to evaluate, passed through to the parent constructor.
    **config
        Evaluator configuration, passed through to the parent constructor.
    """
    super().__init__(gfn_agent, **config)
[docs]
def define_new_metrics(self):
    """
    Describe the metrics this evaluator knows how to compute.

    Returns
    -------
    dict
        Maps each metric key to a dict with its ``"display_name"`` (str) and its
        ``"requirements"`` (list of str).
    """
    # (metric key, display name, requirements), kept in declaration order.
    specs = (
        ("l1", "L1 error", ("density",)),
        ("kl", "KL Div.", ("density",)),
        ("jsd", "Jensen Shannon Div.", ("density",)),
        (
            "corr_probs_rewards",
            "Corr. (test probs., rewards)",
            ("log_probs", "reward_batch"),
        ),
        (
            "corr_logprobs_logrewards",
            "Corr. (test logprobs., logrewards)",
            ("log_probs", "reward_batch"),
        ),
        (
            "var_logrewards_logp",
            "Var(logR - logp) test",
            ("log_probs", "reward_batch"),
        ),
        ("nll_tt", "NLL of test data", ("log_probs",)),
        ("mean_logprobs_std", "Mean BS Std(logp)", ("log_probs",)),
        ("mean_probs_std", "Mean BS Std(p)", ("log_probs",)),
        ("logprobs_std_nll_ratio", "BS Std(logp) / NLL", ("log_probs",)),
    )
    # A fresh list is built for every entry so callers can mutate the result.
    return {
        key: {"display_name": label, "requirements": list(needs)}
        for key, label, needs in specs
    }
# TODO: this method will most likely crash if used (top_k_period != -1) because
# self.gfn.env.top_k_metrics_and_plots still makes use of env.proxy.
# Re-implementing this will require a non-trivial amount of work.
@torch.no_grad()
def eval_top_k(self, it, gfn_states=None, random_states=None):
    """
    Sample from the current GFN and compute metrics and plots for the top k states
    according to both the energy and the reward.

    While sampling, the agent's ``random_action_prob`` is temporarily set to ``0``
    (GFN samples) or ``1.0`` (random samples). Its previous value is restored
    before returning, including when an exception is raised.

    Parameters
    ----------
    it : int
        current iteration
    gfn_states : list, optional
        Already sampled gfn states. Defaults to None, in which case states are
        sampled from the agent.
    random_states : list, optional
        Already sampled random states. Defaults to None, in which case states are
        sampled with uniformly random actions (only when random metrics are due).

    Returns
    -------
    dict
        Computed dict of metrics, and figures, and optionally (only once) summary
        metrics. Schema: ``{"metrics": {str: float}, "figs": {str: plt.Figure},
        "summary": {str: float}}``.
    """
    # only do random top k plots & metrics once
    do_random = it // self.logger.test.top_k_period == 1
    duration = None
    summary = {}
    test_cfg = self.gfn.logger.test

    def sample_batch_with(random_action_prob):
        # Sample n_top_k trajectories from the agent, chunk by chunk, with the
        # given exploration probability, and merge them into a single Batch.
        batch = Batch(
            env=self.gfn.env,
            proxy=self.gfn.proxy,
            device=self.gfn.device,
            float_type=self.gfn.float,
        )
        self.gfn.random_action_prob = random_action_prob
        for b in batch_with_rest(0, test_cfg.n_top_k, self.gfn.batch_size_total):
            sub_batch, _ = self.gfn.sample_batch(n_forward=len(b), train=False)
            batch.merge(sub_batch)
        return batch

    # The value being overridden lives on the agent, so it must be read from the
    # agent too (it is restored onto the agent in the ``finally`` below).
    # TODO: Why deepcopy?
    prob = copy.deepcopy(self.gfn.random_action_prob)
    print()
    try:
        if not gfn_states:
            # sample states from the current gfn
            t = time.time()
            print("Sampling from GFN...", end="\r")
            batch = sample_batch_with(0)
            duration = time.time() - t
            gfn_states = batch.get_terminating_states()

        # compute metrics and get plots
        print("[eval_top_k] Making GFN plots...", end="\r")
        metrics, figs, fig_names = self.gfn.env.top_k_metrics_and_plots(
            gfn_states, test_cfg.top_k, name="gflownet", step=it
        )
        # ``is not None`` so that a (theoretical) 0.0 duration is still reported
        if duration is not None:
            metrics["gflownet top k sampling duration"] = duration

        if do_random:
            # sample random states from uniform actions, unless provided
            if not random_states:
                print("[eval_top_k] Sampling at random...", end="\r")
                batch = sample_batch_with(1.0)
                random_states = batch.get_terminating_states()
            # compute metrics and get plots
            print("[eval_top_k] Making Random plots...", end="\r")
            (
                random_metrics,
                random_figs,
                random_fig_names,
            ) = self.gfn.env.top_k_metrics_and_plots(
                random_states, test_cfg.top_k, name="random", step=None
            )
            # add to current metrics and plots
            summary.update(random_metrics)
            figs += random_figs
            fig_names += random_fig_names

            # compute training data metrics and get plots
            print("[eval_top_k] Making train plots...", end="\r")
            (
                train_metrics,
                train_figs,
                train_fig_names,
            ) = self.gfn.env.top_k_metrics_and_plots(
                None, test_cfg.top_k, name="train", step=None
            )
            # add to current metrics and plots
            summary.update(train_metrics)
            figs += train_figs
            fig_names += train_fig_names
    finally:
        # Never leave the agent with an altered exploration probability.
        self.gfn.random_action_prob = prob

    print(" " * 100, end="\r")
    print("eval_top_k metrics:")
    entries = list(metrics.items()) + list(summary.items())
    if entries:
        max_k = max(len(k) for k, _ in entries) + 1
        print(
            "    • "
            + "\n    • ".join(f"{k:{max_k}}: {v:.4f}" for k, v in entries)
        )
    print()

    # Figures are keyed by their name, as documented in the return schema.
    figs = dict(zip(fig_names, figs))
    return {
        "metrics": metrics,
        "figs": figs,
        "summary": summary,
    }
@torch.no_grad()
def compute_log_prob_metrics(self, x_tt, metrics=None):
    """
    Compute log-probability metrics for the given test data.

    Uses :meth:`~gflownet.gflownet.GFlowNetAgent.estimate_logprobs_data`.

    Known metrics:

    - ``mean_logprobs_std``: Mean of the standard deviation of the log-probabilities.
    - ``mean_probs_std``: Mean of the standard deviation of the probabilities.
    - ``corr_probs_rewards``: Correlation between the probabilities and the
      rewards.
    - ``corr_logprobs_logrewards``: Correlation between the log-probabilities and
      the log-rewards.
    - ``var_logrewards_logp``: Variance of the log-rewards minus the log-probabilities.
    - ``nll_tt``: Negative log-likelihood of the test data.
    - ``logprobs_std_nll_ratio``: Ratio of the mean of the standard deviation of the
      log-probabilities over the negative log-likelihood of the test data.

    Each metric is computed only if it is requested.

    Returned data in the ``"data"`` sub-dict:

    - ``rewards``: Rewards for the test data (if rewards are required).
    - ``logrewards``: Log-rewards of the test data (if rewards are required).
    - ``probs``: Probabilities of the test data (if a correlation is requested).
    - ``logprobs``: Log-probabilities of the test data (if a correlation is
      requested).

    Parameters
    ----------
    x_tt : list or torch.Tensor
        Test data (:meth:`eval` passes a list of states).
    metrics : List[str], optional
        List of metrics to compute, by default ``None`` i.e. the evaluator's
        ``self.metrics``

    Returns
    -------
    dict
        Computed dict of metrics and data as
        ``{"metrics": {str: float}, "data": {str: object}}``.
    """
    metrics = self.make_metrics(metrics)
    reqs = self.make_requirements(metrics=metrics)

    logprobs_x_tt, logprobs_std, probs_std = self.gfn.estimate_logprobs_data(
        x_tt,
        n_trajectories=self.config.n_trajs_logprobs,
        max_data_size=self.config.max_data_logprobs,
        batch_size=self.config.logprobs_batch_size,
        bs_num_samples=self.config.logprobs_bootstrap_size,
    )

    lp_metrics = {}
    lp_data = {}

    if "mean_logprobs_std" in metrics:
        lp_metrics["mean_logprobs_std"] = logprobs_std.mean().item()
    if "mean_probs_std" in metrics:
        lp_metrics["mean_probs_std"] = probs_std.mean().item()

    if "reward_batch" in reqs:
        rewards_x_tt = self.gfn.proxy.rewards(self.gfn.env.states2proxy(x_tt))
        logrewards_x_tt = torch.log(rewards_x_tt)
        lp_data["rewards"] = rewards_x_tt
        lp_data["logrewards"] = logrewards_x_tt

        # The two correlations are independent metrics: each one is computed
        # when (and only when) it is requested.
        want_corr_probs = "corr_probs_rewards" in metrics
        want_corr_logprobs = "corr_logprobs_logrewards" in metrics
        if want_corr_probs or want_corr_logprobs:
            # NumPy cannot consume tensors living on an accelerator, so
            # everything is moved to host memory before calling np.corrcoef.
            logprobs_np = logprobs_x_tt.detach().cpu().numpy()
            probs_x_tt = np.exp(logprobs_np)
            if want_corr_probs:
                lp_metrics["corr_probs_rewards"] = np.corrcoef(
                    probs_x_tt, rewards_x_tt.detach().cpu().numpy()
                )[0, 1]
            if want_corr_logprobs:
                lp_metrics["corr_logprobs_logrewards"] = np.corrcoef(
                    logprobs_np, logrewards_x_tt.detach().cpu().numpy()
                )[0, 1]
            lp_data["probs"] = probs_x_tt
            lp_data["logprobs"] = logprobs_x_tt

        if "var_logrewards_logp" in metrics:
            lp_metrics["var_logrewards_logp"] = torch.var(
                torch.log(
                    tfloat(
                        rewards_x_tt,
                        float_type=self.gfn.float,
                        device=self.gfn.device,
                    )
                )
                - logprobs_x_tt
            ).item()

    if "nll_tt" in metrics:
        lp_metrics["nll_tt"] = -logprobs_x_tt.mean().item()
    if "logprobs_std_nll_ratio" in metrics:
        lp_metrics["logprobs_std_nll_ratio"] = (
            -logprobs_std.mean() / logprobs_x_tt.mean()
        ).item()

    return {
        "metrics": lp_metrics,
        "data": lp_data,
    }
[docs]
def compute_density_metrics(self, x_tt, dict_tt, metrics=None):
    """
    Compute density metrics for the given test data.

    Three regimes are handled, depending on the agent:

    - The test buffer is of type ``"all"``: the predicted density is the empirical
      frequency of the test states among freshly sampled states, and the true
      density is the normalised reward of the test states.
    - The agent is continuous and its environment has ``fit_kde``: both densities
      are estimated with KDEs scored on the test states.
    - Otherwise: the agent's own ``l1``, ``kl`` and ``jsd`` attributes are
      returned as is.

    In the first two regimes, the true density is read from ``dict_tt`` when
    available; otherwise it is computed, added to ``dict_tt`` and the dict is
    pickled back to the test buffer's pickle file.

    Known metrics:

    - ``l1``: L1 error between the true and predicted densities.
    - ``kl``: KL divergence between the true and predicted densities (averaged
      over the test states).
    - ``jsd``: Jensen-Shannon divergence between the true and predicted densities.

    Returned data in the ``"data"`` sub-dict:

    - ``x_sampled``: Sampled states from the GFN (``None`` in the third regime).
    - ``kde_pred``: KDE policy as per
      :meth:`~gflownet.envs.base.GFlowNetEnv.fit_kde` (KDE regime only).
    - ``kde_true``: True KDE (KDE regime only).

    Parameters
    ----------
    x_tt : list or torch.Tensor
        Test data.
    dict_tt : dict
        Dictionary of test data. May be updated in place with the true density.
    metrics : List[str], optional
        List of metrics, by default ``None`` i.e. the evaluator's
        ``self.metrics``. It is passed through ``make_metrics`` but all three
        density metrics are returned regardless.

    Returns
    -------
    dict
        Computed dict of metrics and data as
        ``{"metrics": {str: float}, "data": {str: object}}``.
    """
    metrics = self.make_metrics(metrics)
    out_metrics = {}
    out_data = {}
    x_sampled = None

    exhaustive = self.gfn.buffer.test_config.type == "all"
    use_kde = (
        not exhaustive and self.gfn.continuous and hasattr(self.gfn.env, "fit_kde")
    )

    if not exhaustive and not use_kde:
        # Nothing can be estimated here: report what the agent already holds.
        out_metrics["l1"] = self.gfn.l1
        out_metrics["kl"] = self.gfn.kl
        out_metrics["jsd"] = self.gfn.jsd
        out_data["x_sampled"] = x_sampled
        return {
            "metrics": out_metrics,
            "data": out_data,
        }

    # Both remaining regimes start from a fresh batch of GFN samples.
    sampled_batch, _ = self.gfn.sample_batch(n_forward=self.config.n, train=False)
    assert sampled_batch.is_valid()

    if exhaustive:
        x_sampled = sampled_batch.get_terminating_states()

        if "density_true" in dict_tt:
            density_true = torch2np(dict_tt["density_true"])
        else:
            test_rewards = torch2np(
                self.gfn.proxy.rewards(self.gfn.env.states2proxy(x_tt))
            )
            density_true = test_rewards / test_rewards.sum()
            # Cache the true density in the pickled test dict
            with open(self.gfn.buffer.test_config.pkl, "wb") as f:
                dict_tt["density_true"] = density_true
                pickle.dump(dict_tt, f)

        # Empirical frequency of every test state among the sampled states
        hist = defaultdict(int)
        for state in x_sampled:
            hist[tuple(state)] += 1
        test_counts = [hist[tuple(state)] for state in x_tt]
        z_pred = sum(test_counts) + 1e-9
        density_pred = np.array([count / z_pred for count in test_counts])

        log_density_true = np.log(density_true + 1e-8)
        log_density_pred = np.log(density_pred + 1e-8)
    else:
        x_sampled = sampled_batch.get_terminating_states(proxy=True)
        # TODO make it work with conditional env
        x_tt = torch2np(self.gfn.env.states2proxy(x_tt))
        kde_cfg = self.config.kde
        kde_pred = self.gfn.env.fit_kde(
            x_sampled,
            kernel=kde_cfg.kernel,
            bandwidth=kde_cfg.bandwidth,
        )

        if "log_density_true" in dict_tt and "kde_true" in dict_tt:
            log_density_true = dict_tt["log_density_true"]
            kde_true = dict_tt["kde_true"]
        else:
            # Sample from reward via rejection sampling
            x_from_reward = self.gfn.env.states2proxy(
                self.gfn.sample_from_reward(n_samples=self.config.n)
            )
            # Fit KDE with samples from reward
            kde_true = self.gfn.env.fit_kde(
                x_from_reward,
                kernel=self.config.kde.kernel,
                bandwidth=self.config.kde.bandwidth,
            )
            # Estimate true log density using test samples
            # TODO: this may be specific-ish for the torus or not
            scores_true = kde_true.score_samples(x_tt)
            log_density_true = scores_true - logsumexp(scores_true, axis=0)
            # Add log_density_true and kde_true to pickled test dict
            with open(self.gfn.buffer.test_config.pkl, "wb") as f:
                dict_tt["log_density_true"] = log_density_true
                dict_tt["kde_true"] = kde_true
                pickle.dump(dict_tt, f)

        # Estimate pred log density using test samples
        # TODO: this may be specific-ish for the torus or not
        scores_pred = kde_pred.score_samples(x_tt)
        log_density_pred = scores_pred - logsumexp(scores_pred, axis=0)

        density_true = np.exp(log_density_true)
        density_pred = np.exp(log_density_pred)
        out_data["kde_pred"] = kde_pred
        out_data["kde_true"] = kde_true

    # L1 error
    out_metrics["l1"] = np.abs(density_pred - density_true).mean()
    # KL divergence
    out_metrics["kl"] = (
        density_true * (log_density_true - log_density_pred)
    ).mean()
    # Jensen-Shannon divergence: both halves are measured against the mixture
    log_mean_dens = np.logaddexp(log_density_true, log_density_pred) + np.log(0.5)
    jsd_true_half = 0.5 * np.sum(density_true * (log_density_true - log_mean_dens))
    jsd_pred_half = 0.5 * np.sum(density_pred * (log_density_pred - log_mean_dens))
    out_metrics["jsd"] = jsd_true_half + jsd_pred_half

    out_data["x_sampled"] = x_sampled
    return {
        "metrics": out_metrics,
        "data": out_data,
    }
@torch.no_grad()
def eval(self, metrics=None, **plot_kwargs):
    """
    Evaluate the GFlowNetAgent and compute metrics and the data needed for plots.

    If `metrics` is not provided, the evaluator's `self.metrics` attribute is used
    (default).

    If the agent's buffer has no test set, no metric is computed: each requested
    metric is read from the attribute of the same name on the agent (``None`` if
    the agent has no such attribute).

    Extend in subclasses to add more metrics and data:

    .. code-block:: python

        def eval(self, metrics=None, **plot_kwargs):
            result = super().eval(metrics=metrics, **plot_kwargs)
            result["metrics"]["my_custom_metric"] = my_custom_metric_function()
            result["data"]["my_custom_data"] = my_custom_data_function()
            return result

    Parameters
    ----------
    metrics : List[str], optional
        List of metrics to compute, by default the evaluator's `self.metrics`
        attribute.
    plot_kwargs : dict, optional
        Additional keyword arguments meant for the plotting methods. They are
        accepted for signature compatibility and not used by this implementation.

    Returns
    -------
    dict
        Computed dict of metrics and data as
        `{"metrics": {str: float}, "data": {str: object}}`.
    """
    metrics = self.make_metrics(metrics)
    reqs = self.make_requirements(metrics=metrics)

    if self.gfn.buffer.test is None:
        return {
            "metrics": {k: getattr(self.gfn, k, None) for k in metrics},
            "data": {},
        }

    x_tt = self.gfn.buffer.test.samples.values.tolist()

    all_data = {}
    all_metrics = {}

    # Compute correlation between the rewards of the test data and the log
    # likelihood of the data according to the GFlowNet policy; and NLL.
    # TODO: organise code for better efficiency and readability
    if "log_probs" in reqs:
        lp_results = self.compute_log_prob_metrics(x_tt, metrics=metrics)
        all_metrics.update(lp_results.get("metrics", {}))
        all_data.update(lp_results.get("data", {}))

    if "density" in reqs:
        # The pickled test dict is only needed by the density metrics, so it is
        # only loaded here: log-prob-only evaluations do not require the file.
        # NOTE: pickle.load can execute arbitrary code; the test pickle must
        # come from a trusted source.
        with open(self.gfn.buffer.test_config.pkl, "rb") as f:
            dict_tt = pickle.load(f)
        density_results = self.compute_density_metrics(
            x_tt, dict_tt, metrics=metrics
        )
        all_metrics.update(density_results.get("metrics", {}))
        all_data.update(density_results.get("data", {}))

    return {
        "metrics": all_metrics,
        "data": all_data,
    }
[docs]
def plot(
    self, x_sampled=None, kde_pred=None, kde_true=None, plot_kwargs=None, **kwargs
):
    """
    Plots this evaluator should do, returned as a dict `{str: plt.Figure}` which
    will be logged.

    By default, this method will call the following methods of the GFlowNetAgent's
    environment if they exist:

    - `plot_reward_samples` (only if `x_sampled` is provided)
    - `plot_kde` (for each of the `kde_pred` and `kde_true` arguments provided)
    - `plot_samples_topk` (states are sampled from the agent if `x_sampled` is
      not provided)

    Additionally, if all of `probs`, `rewards`, `logprobs` and `logrewards` are
    passed as keyword arguments, a scatter plot of (log-)rewards against
    (log-)probabilities is produced.

    Extend this method to add more plots:

    .. code-block:: python

        def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs):
            figs = super().plot(x_sampled, kde_pred, kde_true, plot_kwargs)
            figs["My custom plot"] = my_custom_plot_function(x_sampled, kde_pred)
            return figs

    Parameters
    ----------
    x_sampled : list, optional
        List of sampled states.
    kde_pred : sklearn.neighbors.KernelDensity
        KDE policy as per `Environment.fit_kde`
    kde_true : object
        True KDE.
    plot_kwargs : dict, optional
        Additional keyword arguments to pass to the plotting methods. Defaults to
        None, meaning no additional arguments.
    kwargs : dict
        Catch-all for additional arguments. The keys `probs`, `rewards`,
        `logprobs` and `logrewards` are read; the others are ignored.

    Returns
    -------
    dict[str, plt.Figure]
        Dictionary of figures to be logged. The keys are the figure names and the
        values are the figures, or None for the plots that were not made.
    """
    # A fresh dict per call: a mutable default would be shared across calls.
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    probs = kwargs.get("probs")
    rewards = kwargs.get("rewards")
    logprobs = kwargs.get("logprobs")
    logrewards = kwargs.get("logrewards")

    env = self.gfn.env
    fig_kde_pred = fig_kde_true = fig_reward_samples = None
    fig_samples_topk = fig_scatter_rewards_probs = None

    sample_space = None

    def get_sample_space():
        # Fetched at most once, and only if some plot needs it.
        # NOTE(review): assumes two consecutive calls to
        # get_sample_space_and_reward() would return equivalent results.
        nonlocal sample_space
        if sample_space is None:
            sample_space = tuple(self.gfn.get_sample_space_and_reward())
        return sample_space

    if hasattr(env, "plot_reward_samples") and x_sampled is not None:
        sample_space_batch, rewards_sample_space = get_sample_space()
        fig_reward_samples = env.plot_reward_samples(
            x_sampled,
            sample_space_batch,
            rewards_sample_space,
            **plot_kwargs,
        )

    if hasattr(env, "plot_kde") and (kde_pred is not None or kde_true is not None):
        sample_space_batch, _ = get_sample_space()
        if kde_pred is not None:
            fig_kde_pred = env.plot_kde(sample_space_batch, kde_pred, **plot_kwargs)
        if kde_true is not None:
            fig_kde_true = env.plot_kde(sample_space_batch, kde_true, **plot_kwargs)

    # TODO: consider moving this to eval_top_k once fixed
    if hasattr(env, "plot_samples_topk"):
        if x_sampled is None:
            batch, _ = self.gfn.sample_batch(
                n_forward=self.config.n_top_k, train=False
            )
            x_sampled = batch.get_terminating_states()
        # Kept in its own variable so that rewards computed here for the samples
        # never leak into the test-set scatter plot below.
        # NOTE(review): a `rewards` keyword argument, if given, is used as the
        # rewards of `x_sampled` - confirm callers pass matching values.
        topk_rewards = rewards
        if topk_rewards is None:
            topk_rewards = self.gfn.proxy.rewards(env.states2proxy(x_sampled))
        fig_samples_topk = env.plot_samples_topk(
            x_sampled,
            topk_rewards,
            self.config.top_k,
            **plot_kwargs,
        )

    # Plot (log)rewards vs (log)probs for test set
    if all(v is not None for v in (probs, rewards, logprobs, logrewards)):
        import matplotlib.pyplot as plt

        def to_numpy(values):
            # matplotlib cannot read tensors that live on an accelerator.
            if isinstance(values, torch.Tensor):
                return values.detach().cpu().numpy()
            return values

        fig_scatter_rewards_probs, ax = plt.subplots(
            nrows=1, ncols=2, figsize=(8, 4), dpi=150
        )
        ax[0].scatter(to_numpy(rewards), to_numpy(probs))
        ax[0].set_xlabel("Rewards")
        ax[0].set_ylabel("Probs")
        ax[1].scatter(to_numpy(logrewards), to_numpy(logprobs))
        ax[1].set_xlabel("Log-rewards")
        ax[1].set_ylabel("Log-probs")
        fig_scatter_rewards_probs.tight_layout()

    return {
        "True reward and GFlowNet samples": fig_reward_samples,
        "GFlowNet KDE Policy": fig_kde_pred,
        "Reward KDE": fig_kde_true,
        "Samples TopK": fig_samples_topk,
        "Scatterplot Rewards vs. Probs": fig_scatter_rewards_probs,
    }