Source code for gflownet.evaluator.abstract

"""
Abstract evaluator class for GFlowNetAgent.

.. warning::

    Should not be used directly, but subclassed to implement specific evaluators for
    different tasks and environments.

See :class:`~gflownet.evaluator.base.BaseEvaluator` for a default,
concrete implementation of this abstract class.

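This class handles some logic that will be the same for all evaluators.
The only requirements for a subclass are to implement the
:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval`,
:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.plot` and
:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval_top_k` abstract methods.
``eval_top_k`` is called by
:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log_top_k`, while
``eval`` and ``plot`` are called by the
:meth:`~gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log` method: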

.. code-include :: :meth:`gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log`

.. code-include :: :func:`gflownet.evaluator.abstract.AbstractEvaluator.eval_and_log`

.. code-include :: :class:`gflownet.evaluator.abstract.AbstractEvaluator`

.. code-include :: :func:`gflownet.utils.common.gflownet_from_config`

.. code-block:: python

        def eval_and_log(self, it, metrics=None):
            results = self.eval(metrics=metrics)
            for m, v in results["metrics"].items():
                setattr(self.gfn, m, v)

            metrics_to_log = {
                METRICS[k]["display_name"]: v for k, v in results["metrics"].items()
            }

            figs = self.plot(**results["data"])

            self.logger.log_metrics(metrics_to_log, it, self.gfn.use_context)
            self.logger.log_plots(figs, it, use_context=self.gfn.use_context)

See :mod:`gflownet.evaluator` for a full-fledged example and
:mod:`gflownet.evaluator.base` for a concrete implementation of this abstract class.
"""

import os
from abc import ABCMeta, abstractmethod
from typing import Union

from omegaconf import OmegaConf

from gflownet.utils.common import load_gflownet_from_rundir

# purposefully non-documented object, hidden from Sphinx docs
_sentinel = object()
"""
Marker meaning "not computed yet". ``AbstractEvaluator.__init__`` assigns it to
``self.metrics`` and ``self.reqs`` before they are built, because ``None`` already
has a meaning for ``make_metrics`` / ``make_requirements`` (use the current
attribute value).
"""

METRICS = {}
"""
All metrics that can be computed by a ``BaseEvaluator``.

Structured as a dict with the metric names as keys and the metric display names and
requirements as values. Requirements are used to decide which kind of data / samples
is required to compute the metric. Display names are used to log the metrics and to
display them in the console.

Implementations of :class:`AbstractEvaluator` can add new metrics to this dict by
implementing the method :meth:`AbstractEvaluator.define_new_metrics`.
"""

ALL_REQS = {r for m in METRICS.values() for r in m["requirements"]}
"""
Union of all requirements of all metrics in :const:`METRICS`.
"""
class AbstractEvaluator(metaclass=ABCMeta):
    def __init__(self, gfn_agent=None, **config):
        """
        Abstract evaluator class for :class:`GFlowNetAgent`.

        In charge of evaluating the :class:`GFlowNetAgent`, computing metrics,
        plotting figures and optionally logging results using the
        :class:`GFlowNetAgent`'s :class:`Logger`.

        You can use the :meth:`from_dir` or :meth:`from_agent` class methods to
        easily instantiate this class from a run directory or an existing
        in-memory :class:`GFlowNetAgent`.

        Use :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.set_agent` to
        set the evaluator's :class:`GFlowNetAgent` after initialization if it was
        not provided at instantiation as ``MyEvaluator(gfn_agent=...)``.

        This ``__init__`` function will call, in order:

        1. :meth:`update_all_metrics_and_requirements` which uses new metrics
           defined in the :meth:`define_new_metrics` method to update the global
           :const:`METRICS` and :const:`ALL_REQS` variables in classes inheriting
           from :class:`AbstractEvaluator`.
        2. ``self.metrics = self.make_metrics(self.config.metrics)`` using
           :meth:`make_metrics`.
        3. ``self.reqs = self.make_requirements()`` using
           :meth:`make_requirements`.

        Arguments
        ---------
        gfn_agent : GFlowNetAgent, optional
            The GFlowNetAgent to evaluate. By default None. Should be set using
            the :meth:`from_dir` or :meth:`from_agent` class methods.
        config : dict
            The configuration of the evaluator, passed as keyword arguments. Will
            be converted to an OmegaConf instance and stored in the
            ``self.config`` attribute. Its ``metrics`` entry is parsed by
            :meth:`make_metrics`.

        Attributes
        ----------
        config : :class:`omegaconf.OmegaConf`
            The configuration of the evaluator.
        metrics : dict
            Dictionary of metrics to compute, with the metric names as keys and
            the metric display names and requirements as values.
        reqs : set[str]
            The set of requirements for the metrics. Used to decide which kind of
            data / samples is required to compute the metric.
        logger : Logger
            The logger to use to log the results of the evaluation. Set to the
            GFlowNetAgent's logger, only once an agent is provided (here or via
            :meth:`set_agent`).
        """
        self._gfn_agent = gfn_agent
        self.config = OmegaConf.create(config)

        if self._gfn_agent is not None:
            self.logger = self._gfn_agent.logger

        # Placeholders: lets make_metrics / make_requirements detect that the
        # attributes have not been computed yet (None has another meaning there).
        self.metrics = self.reqs = _sentinel

        self.update_all_metrics_and_requirements()

        self.metrics = self.make_metrics(self.config.metrics)
        self.reqs = self.make_requirements()

    @staticmethod
    def _is_gflownet_agent(obj):
        """
        Tell whether ``obj`` is a ``GFlowNetAgent`` or an instance of a subclass.

        The check compares class names along the MRO instead of using
        ``isinstance`` so that ``GFlowNetAgent`` does not need to be imported at
        module level (``from_agent`` imports it locally, presumably to avoid an
        import cycle).

        Parameters
        ----------
        obj : Any
            Object to check.

        Returns
        -------
        bool
            True if any class in ``type(obj).__mro__`` is named ``GFlowNetAgent``.
        """
        return any(klass.__name__ == "GFlowNetAgent" for klass in type(obj).__mro__)

    @property
    def gfn(self):
        """
        Get the ``GFlowNetAgent`` to evaluate.

        This is a read-only property. Use the :meth:`set_agent` method to set the
        ``GFlowNetAgent``.

        Returns
        -------
        :class:`GFlowNetAgent`
            The ``GFlowNetAgent`` to evaluate.

        Raises
        ------
        ValueError
            If the ``GFlowNetAgent`` has not been set.
        """
        if not self._is_gflownet_agent(self._gfn_agent):
            raise ValueError(
                "The GFlowNetAgent has not been set. Use the `from_dir` or `from_agent`"
                + " class methods to instantiate this class or the `set_agent` method"
            )
        return self._gfn_agent

    @gfn.setter
    def gfn(self, _):
        """
        Forbid direct assignment to ``gfn``.

        Raises
        ------
        AttributeError
            Always; use :meth:`set_agent` instead.
        """
        raise AttributeError(
            "The `gfn` attribute is read-only. Use the `set_agent` method to set the"
            + " GFlowNetAgent."
        )

    def set_agent(self, gfn_agent):
        """
        Set the ``GFlowNetAgent`` to evaluate after initialization.

        It is then accessible through the ``self.gfn`` property, and its logger
        becomes ``self.logger``. Instances of ``GFlowNetAgent`` subclasses are
        accepted, consistently with :meth:`from_agent`.

        Parameters
        ----------
        gfn_agent : :class:`GFlowNetAgent`
            The ``GFlowNetAgent`` to evaluate.

        Raises
        ------
        AssertionError
            If ``gfn_agent`` is not a ``GFlowNetAgent``.
        """
        assert self._is_gflownet_agent(gfn_agent), (
            "gfn_agent should be an instance of GFlowNetAgent, but is an instance of "
            + f"{type(gfn_agent)}."
        )
        self._gfn_agent = gfn_agent
        self.logger = gfn_agent.logger

    def define_new_metrics(self):
        """
        Method to be implemented by subclasses to define new metrics.

        Example
        -------
        .. code-block:: python

            def define_new_metrics(self):
                return {
                    "my_custom_metric": {
                        "display_name": "My custom metric",
                        "requirements": ["density", "new_req"],
                    }
                }

        Returns
        -------
        dict or None
            Dictionary of new metrics to add to the global :const:`METRICS` dict.
            The default implementation returns ``None`` (no new metrics).
        """
        pass

    def update_all_metrics_and_requirements(self):
        """
        Merge the metrics of :meth:`define_new_metrics` into the global registries.

        Subclasses should not need to override this method: implement
        :meth:`define_new_metrics` instead. New metrics are added to
        :const:`METRICS` and :const:`ALL_REQS` is recomputed from it.

        Both globals are updated in place, so modules that imported them by
        name keep seeing the up-to-date objects.
        """
        new_metrics = self.define_new_metrics()
        if not new_metrics:
            return
        METRICS.update(new_metrics)
        # Recompute (rather than only add) in case an existing metric was
        # redefined with different requirements.
        all_reqs = {r for m in METRICS.values() for r in m["requirements"]}
        ALL_REQS.clear()
        ALL_REQS.update(all_reqs)

    @classmethod
    def from_dir(
        cls: "AbstractEvaluator",
        path: Union[str, os.PathLike],
        no_wandb: bool = True,
        print_config: bool = False,
        device: str = "cuda",
        load_final_ckpt: bool = True,
    ):
        """
        Instantiate an evaluator of class ``cls`` from a run directory.

        Parameters
        ----------
        cls : BaseEvaluator
            Class to instantiate.
        path : Union[str, os.PathLike]
            Path to the run directory from which to load the GFlowNetAgent.
        no_wandb : bool, optional
            Prevent wandb initialization, by default True
        print_config : bool, optional
            Whether or not to print the resulting (loaded) config, by default False
        device : str, optional
            Device to use for the instantiated GFlowNetAgent, by default "cuda"
        load_final_ckpt : bool, optional
            Use the latest possible checkpoint available in the path, by default
            True

        Returns
        -------
        BaseEvaluator
            Instance of ``cls`` with the GFlowNetAgent loaded from the run.
        """
        gfn_agent, _ = load_gflownet_from_rundir(
            path,
            no_wandb=no_wandb,
            print_config=print_config,
            device=device,
            load_final_ckpt=load_final_ckpt,
        )
        return cls.from_agent(gfn_agent)

    @classmethod
    def from_agent(cls, gfn_agent):
        """
        Instantiate an evaluator of class ``cls`` from a GFlowNetAgent.

        The evaluator is configured with ``gfn_agent.evaluator.config``.

        Parameters
        ----------
        cls : BaseEvaluator
            Evaluator class to instantiate.
        gfn_agent : GFlowNetAgent
            Instance of GFlowNetAgent to use for the evaluator.

        Returns
        -------
        BaseEvaluator
            Instance of ``cls`` with the provided GFlowNetAgent.
        """
        from gflownet.gflownet import GFlowNetAgent

        assert isinstance(gfn_agent, GFlowNetAgent), (
            "gfn_agent should be an instance of GFlowNetAgent, but is an instance of "
            + f"{type(gfn_agent)}."
        )

        return cls(gfn_agent=gfn_agent, **gfn_agent.evaluator.config)

    def make_metrics(self, metrics=None):
        """
        Parse metrics from a dict, list, a string or ``None``.

        - If ``None``, the evaluator's current ``self.metrics`` are returned.
        - If ``"all"``, all metrics in :const:`METRICS` are selected.
        - If a string, it can be a comma-separated list of metric names, with or
          without spaces.
        - If a list, it should be a list of metric names (keys of
          :const:`METRICS`).
        - If a dict, its keys should be metric names and its values will be
          ignored: they will be assigned from :const:`METRICS`.

        OmegaConf containers (``ListConfig`` / ``DictConfig``, e.g. values read
        from ``self.config``) are converted to plain lists / dicts first.

        All metrics must be in :const:`METRICS`.

        Parameters
        ----------
        metrics : Union[str, List[str], dict], optional
            Metrics to compute when running the :meth:`.eval` method. Defaults to
            ``None``, i.e. the evaluator's ``self.metrics``.

        Returns
        -------
        dict
            Dictionary of metrics to compute, with the metric names as keys and
            the metric display names and requirements as values.

        Raises
        ------
        ValueError
            If ``metrics`` has an unsupported type, is an empty string, or
            contains a name that is not in :const:`METRICS`.
        """
        if metrics is None:
            assert self.metrics is not _sentinel, (
                "Error setting self.metrics. This is likely due to the `metrics:`"
                + " entry missing from your eval config. Set it to 'all' to compute all"
                + " metrics or to a comma-separated list of metric names (eg 'l1, kl')."
            )
            return self.metrics

        # ListConfig / DictConfig are not list / dict subclasses.
        if OmegaConf.is_config(metrics):
            metrics = OmegaConf.to_container(metrics, resolve=True)

        if not isinstance(metrics, (str, list, dict)):
            raise ValueError(
                "metrics should be None, a string, a list or a dict,"
                + f" but is {type(metrics)}."
            )

        if metrics == "all":
            metrics = METRICS.keys()

        if isinstance(metrics, str):
            if metrics == "":
                raise ValueError(
                    "`metrics` should not be an empty string. "
                    + "Set to 'all' or a list of metric names or None (null in YAML)."
                )
            # A string without commas yields a single-element list.
            metrics = metrics.split(",")

        if isinstance(metrics, dict):
            metrics = metrics.keys()

        metrics = [m.strip() for m in metrics]

        for m in metrics:
            if m not in METRICS:
                raise ValueError(f"Unknown metric name: {m}")

        return {m: METRICS[m] for m in metrics}

    def make_requirements(self, reqs=None, metrics=None):
        """
        Make requirements for the metrics to compute.

        1. If ``metrics`` is provided, the requirements are computed from the
           ``requirements`` entries of those metrics (non-dict values are first
           parsed with :meth:`make_metrics`).
        2. Otherwise, the requirements are computed from the ``reqs`` argument:

           - If ``reqs`` is ``"all"``, all requirements of all metrics are used.
           - If ``reqs`` is ``None``, the evaluator's ``self.reqs`` attribute is
             used (computed from ``self.metrics`` if not set yet).
           - If ``reqs`` is a list (or an OmegaConf ``ListConfig``), it is used as
             the requirements.

        Parameters
        ----------
        reqs : Union[str, List[str], set[str]], optional
            The metrics requirements. Either ``"all"``, a list or set of
            requirements or ``None`` to use the evaluator's ``self.reqs``
            attribute. By default ``None``.
        metrics : Union[str, List[str], dict], optional
            The metrics to compute requirements for. If not a dict, will be passed
            to :meth:`make_metrics`. By default None.

        Returns
        -------
        set[str]
            The set of requirements for the metrics.

        Raises
        ------
        ValueError
            For an unknown metric / requirement name, a string other than
            ``"all"``, or when ``reqs`` is ``None`` and ``self.metrics`` is unset.
        """
        if metrics is not None:
            if not isinstance(metrics, dict):
                metrics = self.make_metrics(metrics)
            for m in metrics:
                if m not in METRICS:
                    raise ValueError(f"Unknown metric name: {m}")
            return {r for m in metrics.values() for r in m["requirements"]}

        # ListConfig (e.g. from self.config) is not a list subclass.
        if reqs is not None and OmegaConf.is_config(reqs):
            reqs = OmegaConf.to_container(reqs, resolve=True)

        if isinstance(reqs, str):
            if reqs == "all":
                reqs = ALL_REQS.copy()
            else:
                raise ValueError(
                    "reqs should be 'all', a list of requirements or None, but is "
                    + f"{reqs}."
                )
        if reqs is None:
            if self.reqs is _sentinel:
                if not isinstance(self.metrics, dict):
                    raise ValueError(
                        "Cannot compute requirements from `None` without the `metrics`"
                        + " argument or the `self.metrics` attribute set to a dict"
                        + " of metrics."
                    )
                self.reqs = {
                    r for m in self.metrics.values() for r in m["requirements"]
                }
            reqs = self.reqs
        if isinstance(reqs, list):
            reqs = set(reqs)

        assert isinstance(
            reqs, set
        ), f"reqs should be None, 'all', a set or a list, but is {type(reqs)}"

        assert all([isinstance(r, str) for r in reqs]), (
            "All elements of reqs should be strings, but are "
            + f"{[type(r) for r in reqs]}"
        )

        for r in reqs:
            if r not in ALL_REQS:
                raise ValueError(f"Unknown requirement: {r}")

        return reqs

    def should_log_train(self, step):
        """
        Check if training logs should be done at the current step.

        The decision is based on ``self.config.train_log_period``: logging happens
        every ``train_log_period`` steps. Set it to ``None`` or a non-positive
        value to disable training logs.

        Parameters
        ----------
        step : int
            Current iteration step.

        Returns
        -------
        bool
            True if train logging should be done at the current step, False
            otherwise.
        """
        period = self.config.train_log_period
        if period is None or period <= 0:
            return False
        return step % period == 0

    def should_eval(self, step):
        """
        Check if testing should be done at the current step.

        The decision is based on ``self.config.period``: testing happens every
        ``period`` steps. Set ``self.config.first_it`` to ``True`` to also test at
        the first iteration step (step 1). Set ``self.config.period`` to ``None``
        or a non-positive value to disable testing (this also disables
        ``first_it``).

        Parameters
        ----------
        step : int
            Current iteration step.

        Returns
        -------
        bool
            True if testing should be done at the current step, False otherwise.
        """
        period = self.config.period
        if period is None or period <= 0:
            return False
        if step == 1 and self.config.first_it:
            return True
        return step % period == 0

    def should_eval_top_k(self, step):
        """
        Check if top k plots and metrics should be done at the current step.

        The decision is based on ``self.config.top_k`` and
        ``self.config.top_k_period``: both must be positive (not ``None``),
        otherwise top k plots and metrics are disabled. When enabled, they happen
        every ``top_k_period`` steps, and at step 1 if ``self.config.first_it``
        is ``True``.

        Parameters
        ----------
        step : int
            Current iteration step.

        Returns
        -------
        bool
            True if top k plots and metrics should be done at the current step,
            False otherwise.
        """
        if self.config.top_k is None or self.config.top_k <= 0:
            return False

        period = self.config.top_k_period
        if period is None or period <= 0:
            return False

        if step == 1 and self.config.first_it:
            return True

        return step % period == 0

    def should_checkpoint(self, step):
        """
        Check if checkpoints should be done at the current step.

        The decision is based on ``self.config.checkpoints_period``: checkpoints
        happen every ``checkpoints_period`` steps. Set it to ``None`` or a
        non-positive value to disable checkpoints.

        Parameters
        ----------
        step : int
            Current iteration step.

        Returns
        -------
        bool
            True if checkpoints should be done at the current step, False
            otherwise.
        """
        period = self.config.checkpoints_period
        if period is None or period <= 0:
            return False
        return step % period == 0

    @abstractmethod
    def plot(self, **kwargs):
        """
        The main method to plot results.

        Will be called by the :meth:`eval_and_log` method to plot the results
        of the evaluation. Will be passed the results of the :meth:`eval` method:

        .. code-block:: python

            # in eval_and_log
            results = self.eval(metrics=metrics)
            figs = self.plot(**results["data"])

        Returns
        -------
        dict
            Dictionary of figures to log, with the figure names as keys and the
            figures as values.
        """
        pass

    @abstractmethod
    def eval(self, metrics=None, **plot_kwargs):
        """
        The main method to compute metrics and intermediate results.

        This method should return a dict with two keys: ``"metrics"`` and
        ``"data"``.

        The "metrics" key should contain the new metric(s) and the "data" key
        should contain the intermediate results that can be used to plot the
        new metric(s).

        Example
        -------
        >>> metrics = None # use the default metrics from the config file
        >>> results = gfne.eval(metrics=metrics)
        >>> plots = gfne.plot(**results["data"])

        >>> metrics = "all" # compute all metrics, regardless of the config
        >>> results = gfne.eval(metrics=metrics)

        >>> metrics = ["l1", "kl"] # compute only the L1 and KL metrics
        >>> results = gfne.eval(metrics=metrics)

        >>> metrics = "l1,kl" # alternative syntax
        >>> results = gfne.eval(metrics=metrics)

        See :ref:`evaluator basic concepts` for more details about ``metrics``.

        Parameters
        ----------
        metrics : Union[str, dict, list], optional
            Which metrics to compute, by default ``None``.
        """
        pass

    @abstractmethod
    def eval_top_k(self, it):
        """
        Evaluate the ``GFlowNetAgent``'s top k samples performance.

        Classes extending this abstract class should implement this method.

        Parameters
        ----------
        it : int
            Current iteration step.

        Returns
        -------
        dict
            Dictionary with the following keys schema:

            .. code-block:: python

                {
                    "metrics": {str: float},
                    "figs": {str: plt.Figure},
                    "summary": {str: float},
                }
        """
        pass

    def eval_and_log(self, it, metrics=None):
        """
        Evaluate the GFlowNetAgent and log the results with its logger.

        Will call ``self.eval()``, store each computed metric as an attribute of
        the GFlowNetAgent, and log the results using the GFlowNetAgent's logger
        ``log_metrics()`` and ``log_plots()`` methods.

        Parameters
        ----------
        it : int
            Current iteration step.
        metrics : Union[str, List[str]], optional
            List of metrics to compute, by default the evaluator's ``metrics``
            attribute.
        """
        results = self.eval(metrics=metrics)
        for m, v in results["metrics"].items():
            setattr(self.gfn, m, v)

        metrics_to_log = {
            METRICS[k]["display_name"]: v for k, v in results["metrics"].items()
        }

        figs = self.plot(**results["data"])

        self.logger.log_metrics(metrics_to_log, it, self.gfn.use_context)
        self.logger.log_plots(figs, it, use_context=self.gfn.use_context)

    def eval_and_log_top_k(self, it):
        """
        Evaluate the GFlowNetAgent's top k samples performance and log the
        results with its logger.

        Calls :meth:`eval_top_k` and logs its ``"figs"``, ``"metrics"`` and
        ``"summary"`` entries.

        Parameters
        ----------
        it : int
            Current iteration step.
        """
        results = self.eval_top_k(it)
        # The evaluator has no ``use_context`` of its own: read it from the agent,
        # as eval_and_log does.
        use_context = self.gfn.use_context
        self.logger.log_plots(results["figs"], it, use_context=use_context)
        self.logger.log_metrics(results["metrics"], use_context=use_context, step=it)
        self.logger.log_summary(results["summary"])