gflownet.utils.logger

Classes

Logger

Utility functions to compute and handle statistics (saving them or sending them to wandb).

Module Contents

class gflownet.utils.logger.Logger(config, do, project_name, logdir, lightweight, debug, run_name=None, run_name_date=True, run_name_job=True, run_id=None, tags=None, context='0', notes=None, entity=None, progressbar={'skip': False, 'n_iters_mean': 100}, is_resumed=False)

Utility functions to compute and handle statistics (saving them or sending them to wandb). An instance can be passed to the querier, gfn, proxy, … to track statistics of the training and of the generated data in real time.

Parameters:
  • run_name (str) – Name of the run. None by default. If run_name is None and both run_name_date and run_name_job are False, wandb assigns a random name.

  • run_name_date (bool) – Whether the date (and time) should be included in the run name. True by default.

  • run_name_job (bool) – Whether the job ID should be included in the run name. True by default.

  • progressbar (dict) –

    A dictionary of configuration parameters related to the progress bar, namely:
    • skip (bool)

      If True, the progress bar is not displayed during training. False by default.

    • n_iters_mean (int)

      The number of past iterations to take into account to compute averages of a metric, for example the loss. 100 by default.

  • config (dict)

  • do (dict)

  • project_name (str)

  • logdir (dict)

  • lightweight (bool)

  • debug (bool)

  • run_id (str)

  • tags (list)

  • context (str)

  • notes (str)

  • entity (str)

  • is_resumed (bool)
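The run_name, run_name_date and run_name_job options combine into the final run name. The sketch below illustrates that combination with a hypothetical helper, build_run_name, which is not part of gflownet. It assumes the job ID is read from the SLURM_JOB_ID environment variable and that the parts are joined with underscores; the actual format used by Logger may differ.

```python
import os
from datetime import datetime
from typing import Optional


def build_run_name(
    run_name: Optional[str] = None,
    run_name_date: bool = True,
    run_name_job: bool = True,
    now: Optional[datetime] = None,
) -> Optional[str]:
    """Hypothetical helper mirroring the documented run_name options.

    Assumption: the job ID comes from SLURM_JOB_ID and parts are joined
    with "_". gflownet's real naming scheme may differ.
    """
    parts = []
    if run_name is not None:
        parts.append(run_name)
    if run_name_date:
        now = now or datetime.now()
        parts.append(now.strftime("%Y%m%d-%H%M%S"))
    if run_name_job:
        job_id = os.environ.get("SLURM_JOB_ID")  # assumed source of the job ID
        if job_id is not None:
            parts.append(job_id)
    # Nothing to build from: return None so that wandb assigns a random name.
    return "_".join(parts) if parts else None


print(build_run_name("ctorus", run_name_job=False, now=datetime(2024, 1, 2, 3, 4, 5)))
# ctorus_20240102-030405
print(build_run_name(None, run_name_date=False, run_name_job=False))
# None
```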

config
do
context = '0'
progressbar
loss_memory = []
lightweight
debug
is_resumed = False
ckpts_dir
datadir
write_url_file()
add_tags(tags)
Parameters:

tags (list)

set_context(context)
Parameters:

context (int)

progressbar_update(pbar, loss, rewards, jsd, use_context=True, n_mean=100)
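The n_iters_mean setting (and the n_mean argument of progressbar_update) describes an average over the most recent iterations only. A minimal sketch of such a running mean follows, using collections.deque with a fixed maxlen. RunningMean is a hypothetical class for illustration; it is an assumption about the behavior, not gflownet's implementation of loss_memory.

```python
from collections import deque


class RunningMean:
    """Mean of the last n values (hypothetical; not part of gflownet)."""

    def __init__(self, n: int = 100):
        # deque(maxlen=n) silently drops the oldest value once it is full.
        self.values = deque(maxlen=n)

    def update(self, value: float) -> float:
        self.values.append(value)
        return sum(self.values) / len(self.values)


loss_mean = RunningMean(n=3)
for loss in [4.0, 2.0, 3.0, 1.0]:
    avg = loss_mean.update(loss)
print(avg)  # only the last 3 losses count: (2.0 + 3.0 + 1.0) / 3 = 2.0
```

A bounded deque keeps memory constant regardless of training length, which suits a per-iteration progress bar.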
log_histogram(key, value, step, use_context=True)
log_plots(figs, step, use_context=True)
Parameters:

figs (Union[dict, list])

close_figs(figs)
Parameters:

figs (list)

log_rewards_and_scores(rewards, logrewards, scores, step, prefix, use_context=True)

Logs the rewards, log-rewards and proxy scores passed as arguments.

Parameters:
  • rewards (tensor) – Rewards of a batch of states.

  • logrewards (tensor) – Log-rewards of a batch of states.

  • scores (tensor) – Proxy scores of a batch of states.

  • step (int) – The training iteration number.

  • prefix (str) – Prefix to be added to the metric names.

  • use_context (bool) – If True, prepend self.context followed by a slash ("/") to each metric key.
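To illustrate how prefix and use_context shape the metric names, here is a sketch with a hypothetical function, reward_metrics. It assumes the batch mean and maximum are the summaries of interest and that keys take the form "<context>/<prefix>_<stat>_<name>"; the real method may log other statistics (for example histograms) and use a different key layout. Plain Python lists stand in for the tensors.

```python
from statistics import mean
from typing import Dict, Sequence


def reward_metrics(
    rewards: Sequence[float],
    logrewards: Sequence[float],
    scores: Sequence[float],
    prefix: str,
    context: str = "0",
    use_context: bool = True,
) -> Dict[str, float]:
    """Hypothetical sketch: summarize a batch and build prefixed metric keys.

    Assumption: keys are "<context>/<prefix>_<stat>_<name>", with mean and max
    as the summary statistics. Not gflownet's actual key format.
    """
    metrics = {}
    for name, values in (("rewards", rewards), ("logrewards", logrewards), ("scores", scores)):
        metrics[f"{prefix}_mean_{name}"] = mean(values)
        metrics[f"{prefix}_max_{name}"] = max(values)
    if use_context:
        # Documented behavior: prepend the context and "/" to each key.
        metrics = {f"{context}/{k}": v for k, v in metrics.items()}
    return metrics


m = reward_metrics([1.0, 3.0], [0.0, 1.1], [0.5, 0.7], prefix="train")
print(m["0/train_mean_rewards"], m["0/train_max_scores"])  # 2.0 0.7
```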

log_metrics(metrics, step, use_context=True)

Logs metrics to wandb.

Parameters:
  • metrics (dict) – A dictionary of metrics to be logged to wandb.

  • step (int) – The training iteration number.

  • use_context (bool) – If True, prepend self.context followed by a slash ("/") to each metric key.
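Code that receives a Logger can be unit-tested without a wandb connection by substituting an object with the same set_context and log_metrics signatures. InMemoryLogger below is a hypothetical test double, not part of gflownet: it applies the documented context prefix and records each call in a list instead of sending it to wandb.

```python
from typing import Any, Dict, List, Tuple


class InMemoryLogger:
    """Hypothetical test double mimicking set_context/log_metrics.

    Records (step, metrics) pairs in memory instead of logging to wandb.
    """

    def __init__(self, context: str = "0"):
        self.context = context
        self.records: List[Tuple[int, Dict[str, Any]]] = []

    def set_context(self, context) -> None:
        # Accept int or str contexts; keys always use the string form.
        self.context = str(context)

    def log_metrics(self, metrics: Dict[str, Any], step: int, use_context: bool = True) -> None:
        if use_context:
            metrics = {f"{self.context}/{k}": v for k, v in metrics.items()}
        self.records.append((step, metrics))


logger = InMemoryLogger()
logger.log_metrics({"loss": 0.5}, step=10)
logger.set_context(1)
logger.log_metrics({"loss": 0.4}, step=11)
logger.log_metrics({"lr": 1e-3}, step=11, use_context=False)
print(logger.records)
# [(10, {'0/loss': 0.5}), (11, {'1/loss': 0.4}), (11, {'lr': 0.001})]
```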

log_summary(summary)
Parameters:

summary (dict)

save_checkpoint(forward_policy, backward_policy, state_flow, logZ, optimizer, buffer, step, final=False)
Parameters:
  • step (int)

  • final (bool)

log_time(times, use_context)
Parameters:
  • times (dict)

  • use_context (bool)
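One way to build the times dictionary passed to log_time is to time named sections of a training iteration. The context manager below, timed, is a hypothetical helper that assumes the values of times are wall-clock durations in seconds; the key names are illustrative only.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def timed(times: Dict[str, float], key: str) -> Iterator[None]:
    """Accumulate elapsed wall-clock seconds for `key` into `times` (hypothetical)."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Add to any previous duration so repeated sections are summed.
        times[key] = times.get(key, 0.0) + time.perf_counter() - start


times = {}
with timed(times, "sample"):
    time.sleep(0.01)
with timed(times, "train"):
    pass
# With a real Logger instance: logger.log_time(times, use_context=True)
print(sorted(times))  # ['sample', 'train']
```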

end()