gflownet.evaluator.base
=======================

.. py:module:: gflownet.evaluator.base

.. autoapi-nested-parse::

   Base evaluator class for a :class:`~gflownet.gflownet.GFlowNetAgent`.

   In charge of evaluating a generic :class:`~gflownet.gflownet.GFlowNetAgent`:
   computing metrics, plotting figures and optionally logging results using the
   :class:`~gflownet.gflownet.GFlowNetAgent`'s :class:`~gflownet.utils.logger.Logger`.

   Use this :class:`BaseEvaluator` as an example when implementing your own
   evaluator class for a custom use case.

   .. important::

      Prefer the :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.from_dir`
      and :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.from_agent` class
      methods to instantiate an evaluator.

   See :ref:`using an evaluator` for more details about how to use an evaluator.


Classes
-------

.. autoapisummary::

   gflownet.evaluator.base.BaseEvaluator


Module Contents
---------------

.. py:class:: BaseEvaluator(gfn_agent=None, **config)

   Bases: :py:obj:`gflownet.evaluator.abstract.AbstractEvaluator`

   Base evaluator class for :class:`~gflownet.gflownet.GFlowNetAgent`.

   In particular, it implements :meth:`eval` with:

   - :meth:`compute_log_prob_metrics` to compute log-probability metrics.
   - :meth:`compute_density_metrics` to compute density metrics.

   And the :meth:`plot` method with:

   - The :class:`~gflownet.envs.base.GFlowNetEnv`'s :meth:`plot_reward_samples`
     method.
   - The :class:`~gflownet.envs.base.GFlowNetEnv`'s :meth:`plot_kde` method, if it
     exists, for both the ``kde_pred`` and ``kde_true`` arguments if they are
     returned in the ``"data"`` dict of the :meth:`eval` method.

   See :class:`~gflownet.evaluator.abstract.AbstractEvaluator` for more details
   about other methods and attributes, including
   :meth:`~gflownet.evaluator.abstract.AbstractEvaluator.__init__`.


   .. py:method:: define_new_metrics()

      Method to be implemented by subclasses to define new metrics.

      .. admonition:: Example

         .. code-block:: python

            def define_new_metrics(self):
                return {
                    "my_custom_metric": {
                        "display_name": "My custom metric",
                        "requirements": ["density", "new_req"],
                    }
                }

      :returns: *dict* -- Dictionary of new metrics to add to the global :const:`METRICS` dict.


   .. py:method:: eval_top_k(it, gfn_states=None, random_states=None)

      Sample from the current GFN and compute metrics and plots for the top-k
      states according to both the energy and the reward.

      :param it: Current iteration.
      :type it: int
      :param gfn_states: Already sampled GFN states. Defaults to ``None``.
      :type gfn_states: list, optional
      :param random_states: Already sampled random states. Defaults to ``None``.
      :type random_states: list, optional
      :returns: *dict* -- Computed dict of metrics and figures, and optionally (only once) summary metrics. Schema: ``{"metrics": {str: float}, "figs": {str: plt.Figure}, "summary": {str: float}}``.


   .. py:method:: compute_log_prob_metrics(x_tt, metrics=None)

      Compute log-probability metrics for the given test data.

      Uses :meth:`~gflownet.gflownet.GFlowNetAgent.estimate_logprobs_data`.

      Known metrics:

      - ``mean_logprobs_std``: Mean of the standard deviation of the log-probabilities.
      - ``mean_probs_std``: Mean of the standard deviation of the probabilities.
      - ``corr_probs_rewards``: Correlation between the probabilities and the rewards.
      - ``corr_logprobs_logrewards``: Correlation between the log-probabilities and the log-rewards.
      - ``var_logrewards_logp``: Variance of the log-rewards minus the log-probabilities.
      - ``nll_tt``: Negative log-likelihood of the test data.
      - ``logprobs_std_nll_ratio``: Ratio of the mean of the standard deviation of the log-probabilities to the negative log-likelihood of the test data.

      Returned data in the ``"data"`` sub-dict:

      - ``probs``: Probabilities of the test data.
      - ``rewards``: Rewards of the test data.
      - ``logprobs``: Log-probabilities of the test data.
      - ``logrewards``: Log-rewards of the test data.

      :param x_tt: Test data.
      :type x_tt: torch.Tensor
      :param metrics: List of metrics to compute. Defaults to ``None``, i.e. the evaluator's ``self.metrics``.
      :type metrics: List[str], optional
      :returns: *dict* -- Computed dict of metrics and data as ``{"metrics": {str: float}, "data": {str: object}}``.


   .. py:method:: compute_density_metrics(x_tt, dict_tt, metrics=None)

      Compute density metrics for the given test data.

      Known metrics:

      - ``l1``: L1 error between the true and predicted densities.
      - ``kl``: KL divergence between the true and predicted densities.
      - ``jsd``: Jensen-Shannon divergence between the true and predicted densities.

      Returned data in the ``"data"`` sub-dict:

      - ``x_sampled``: States sampled from the GFN.
      - ``kde_pred``: KDE of the GFN policy, as per :meth:`~gflownet.envs.base.GFlowNetEnv.fit_kde`.
      - ``kde_true``: True KDE.

      :param x_tt: Test data.
      :type x_tt: torch.Tensor
      :param dict_tt: Dictionary of test data.
      :type dict_tt: dict
      :param metrics: List of metrics to compute. Defaults to ``None``, i.e. the evaluator's ``self.metrics``.
      :type metrics: List[str], optional
      :returns: *dict* -- Computed dict of metrics and data as ``{"metrics": {str: float}, "data": {str: object}}``.


   .. py:method:: eval(metrics=None, **plot_kwargs)

      Evaluate the GFlowNetAgent and compute metrics and plots.

      If ``metrics`` is not provided, the evaluator's ``self.metrics`` attribute is used (default).

      Extend in subclasses to add more metrics and plots:

      .. code-block:: python

         def eval(self, metrics=None, **plot_kwargs):
             result = super().eval(metrics=metrics, **plot_kwargs)
             # my_custom_metric_function and my_custom_plot_function are user-defined
             result["metrics"]["my_custom_metric"] = my_custom_metric_function()
             result["figs"]["My custom plot"] = my_custom_plot_function()
             return result

      :param metrics: List of metrics to compute. Defaults to the evaluator's ``self.metrics`` attribute.
      :type metrics: List[str], optional
      :param plot_kwargs: Additional keyword arguments to pass to the plotting methods.
      :type plot_kwargs: dict, optional
      :returns: *dict* -- Computed dict of metrics and figures as ``{"metrics": {str: float}, "figs": {str: plt.Figure}}``.


   .. py:method:: plot(x_sampled=None, kde_pred=None, kde_true=None, plot_kwargs={}, **kwargs)

      Produce the plots for this evaluator, returned as a dict ``{str: plt.Figure}`` to be logged.

      By default, this method calls the following methods of the GFlowNetAgent's environment, if they exist:

      - ``plot_reward_samples``
      - ``plot_kde`` (for both the ``kde_pred`` and ``kde_true`` arguments)
      - ``plot_samples_topk``

      Extend this method to add more plots:

      .. code-block:: python

         def plot(self, x_sampled=None, kde_pred=None, kde_true=None, plot_kwargs={}, **kwargs):
             figs = super().plot(x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs)
             # my_custom_plot_function is user-defined
             figs["My custom plot"] = my_custom_plot_function(x_sampled, kde_pred)
             return figs

      :param x_sampled: List of sampled states.
      :type x_sampled: list, optional
      :param kde_pred: KDE of the GFN policy, as per :meth:`~gflownet.envs.base.GFlowNetEnv.fit_kde`.
      :type kde_pred: sklearn.neighbors.KernelDensity
      :param kde_true: True KDE.
      :type kde_true: object
      :param plot_kwargs: Additional keyword arguments to pass to the plotting methods.
      :type plot_kwargs: dict
      :param kwargs: Catch-all for additional arguments.
      :type kwargs: dict
      :returns: *dict[str, plt.Figure]* -- Dictionary of figures to be logged. The keys are the figure names and the values are the figures.
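

As a rough, self-contained illustration of the quantities listed under
:meth:`~gflownet.evaluator.base.BaseEvaluator.compute_density_metrics` and
:meth:`~gflownet.evaluator.base.BaseEvaluator.compute_log_prob_metrics`, here is a
minimal NumPy sketch. It assumes the true and predicted densities are given as
arrays over the same set of states, and that per-sample log-probabilities and
log-rewards are already available (for instance from
:meth:`~gflownet.gflownet.GFlowNetAgent.estimate_logprobs_data`). The helpers
``density_metrics`` and ``log_prob_metrics`` are hypothetical and not part of
``gflownet``; the library's own reductions (for example sum vs. mean for ``l1``,
or biased vs. unbiased variance) may differ.

.. code-block:: python

   import numpy as np


   def density_metrics(density_true, density_pred, eps=1e-12):
       """Hypothetical helper (not part of gflownet): compare two discrete densities.

       Both inputs are non-negative 1-D arrays over the same states; they are
       normalised to sum to 1. ``eps`` guards against log(0).
       """
       p = np.asarray(density_true, dtype=float)
       q = np.asarray(density_pred, dtype=float)
       p, q = p / p.sum(), q / q.sum()
       log_p, log_q = np.log(p + eps), np.log(q + eps)
       # L1 error: total absolute difference between the two distributions
       l1 = np.abs(p - q).sum()
       # KL(true || pred); states with p == 0 contribute nothing
       kl = np.sum(p * (log_p - log_q))
       # Jensen-Shannon divergence: average KL of p and q to their mixture m
       m = 0.5 * (p + q)
       log_m = np.log(m + eps)
       jsd = 0.5 * np.sum(p * (log_p - log_m)) + 0.5 * np.sum(q * (log_q - log_m))
       return {"l1": float(l1), "kl": float(kl), "jsd": float(jsd)}


   def log_prob_metrics(logprobs, logrewards):
       """Hypothetical helper (not part of gflownet): log-probability metrics.

       ``logprobs`` are the model's log-probabilities of the test states and
       ``logrewards`` their log-rewards, as 1-D arrays of equal length.
       """
       logprobs = np.asarray(logprobs, dtype=float)
       logrewards = np.asarray(logrewards, dtype=float)
       return {
           # Pearson correlations in probability space and in log space
           "corr_probs_rewards": float(
               np.corrcoef(np.exp(logprobs), np.exp(logrewards))[0, 1]
           ),
           "corr_logprobs_logrewards": float(np.corrcoef(logprobs, logrewards)[0, 1]),
           # Population variance (ddof=0); 0 when logprobs == logrewards + const
           "var_logrewards_logp": float(np.var(logrewards - logprobs)),
           # Negative log-likelihood of the test data, averaged over samples
           "nll_tt": float(-logprobs.mean()),
       }

For example, ``density_metrics([0.5, 0.5], [1.0, 0.0])`` gives an ``l1`` of 1.0
and a ``jsd`` of 0.75 * ln(4/3), roughly 0.216, while ``kl`` is dominated by the
``eps`` term because the prediction assigns zero mass to a state with positive
true density. The bounded Jensen-Shannon divergence is often the more robust
comparison in that situation.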