# Source code for gflownet.envs.composite.utils

from typing import Iterable, Optional

from gflownet.envs.base import GFlowNetEnv
from gflownet.envs.composite.setfix import SetFix
from gflownet.envs.composite.setflex import SetFlex


def make_set(
    is_flexible: bool,
    subenvs: Optional[Iterable[GFlowNetEnv]] = None,
    max_elements: Optional[int] = None,
    envs_unique: Optional[Iterable[GFlowNetEnv]] = None,
    do_random_subenvs: bool = False,
    **kwargs,
):
    """
    Factory method to create a SetFix or a SetFlex environment depending on the
    input is_flexible.

    This method mimics conditional inheritance.

    Parameters
    ----------
    is_flexible : bool
        If True, return a SetFlex environment. If False, return a SetFix
        environment.
    subenvs : iterable of GFlowNetEnv, optional
        Forwarded as ``subenvs`` to whichever constructor is selected. It is
        mandatory when ``is_flexible`` is False.
    max_elements : int, optional
        Forwarded as ``max_elements`` to SetFlex. It is mandatory when
        ``is_flexible`` is True and it is not used when building a SetFix.
    envs_unique : iterable of GFlowNetEnv, optional
        Forwarded as ``envs_unique`` to SetFlex. Not used when building a
        SetFix.
    do_random_subenvs : bool
        Forwarded as ``do_random_subenvs`` to SetFlex. Not used when building
        a SetFix.
    **kwargs
        Any additional keyword arguments, forwarded unchanged to the selected
        constructor.

    Returns
    -------
    SetFlex or SetFix
        A new SetFlex instance if ``is_flexible`` is True, otherwise a new
        SetFix instance.

    Raises
    ------
    ValueError
        If ``is_flexible`` is False and ``subenvs`` is None, or if
        ``is_flexible`` is True and ``max_elements`` is None.
    """
    if is_flexible:
        # A flexible set cannot be built without max_elements.
        if max_elements is None:
            raise ValueError("max_elements must be defined to use the SetFlex")
        return SetFlex(
            max_elements=max_elements,
            envs_unique=envs_unique,
            subenvs=subenvs,
            do_random_subenvs=do_random_subenvs,
            **kwargs,
        )

    # A fixed set is defined by its sub-environments, so they are mandatory.
    if subenvs is None:
        raise ValueError("subenvs must be defined to use the SetFix")
    # max_elements, envs_unique and do_random_subenvs are SetFlex-only options
    # and are deliberately not forwarded here.
    return SetFix(subenvs=subenvs, **kwargs)