gflownet.proxy.tetris
=====================

.. py:module:: gflownet.proxy.tetris


Classes
-------

.. autoapisummary::

   gflownet.proxy.tetris.Tetris


Module Contents
---------------

.. py:class:: Tetris(normalize, **kwargs)

   Bases: :py:obj:`gflownet.proxy.base.Proxy`

   Base Proxy class for GFlowNet proxies.

   A proxy is the input to a reward function. Depending on the
   ``reward_function``, the reward may be directly the output of the proxy or
   a function of it.

   :param device: The device to be passed to torch tensors.
   :type device: str or torch.device
   :param float_precision: The floating point precision to be passed to torch
                           tensors.
   :type float_precision: int or torch.dtype
   :param reward_function: The transformation applied to the proxy outputs to
                           obtain a GFlowNet reward. See
                           :py:meth:`Proxy._get_reward_functions`.
   :type reward_function: str or Callable
   :param logreward_function: The transformation applied to the proxy outputs
                              to obtain a GFlowNet log reward. See
                              :py:meth:`Proxy._get_reward_functions`. If None
                              (default), the log of the reward function is
                              used. A Callable may be passed to improve the
                              numerical stability of the transformation.
   :type logreward_function: Callable
   :param reward_function_kwargs: A dictionary of arguments to be passed to
                                  the reward function.
   :type reward_function_kwargs: dict
   :param reward_min: The minimum value allowed for rewards, 0.0 by default,
                      which results in a minimum log reward of
                      :py:const:`LOGZERO`. Note that certain loss functions,
                      for example the Forward Looking loss, may not work as
                      desired if the minimum reward is 0.0. It may be set to a
                      small positive value close to zero in order to prevent
                      numerical stability issues.
   :type reward_min: float
   :param do_clip_rewards: Whether to clip the rewards according to the
                           minimum value.
   :type do_clip_rewards: bool


   .. py:attribute:: normalize


   .. py:method:: setup(env=None)


   .. py:property:: norm


   .. py:method:: __call__(states)

      Implement this function to call the get_reward method of the
      appropriate Proxy class (EI, UCB, Proxy, Oracle, etc.).

      :param states:
      :type states: ndarray
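
The parameter list above describes how ``reward_min`` and ``do_clip_rewards`` interact with the log reward. The following library-independent sketch illustrates that interaction. The helper names ``clip_rewards`` and ``to_logrewards`` and the value assigned to ``LOGZERO`` are assumptions made for illustration only; they are not part of the ``gflownet`` API, and the actual implementation may differ.

```python
import math

# Assumed placeholder for the library's LOGZERO constant; the real value may differ.
LOGZERO = -1e3


def clip_rewards(rewards, reward_min=0.0):
    """Raise every reward below reward_min up to reward_min (hypothetical helper)."""
    return [max(r, reward_min) for r in rewards]


def to_logrewards(rewards, logzero=LOGZERO):
    """Take the log of each reward, mapping non-positive rewards to logzero."""
    return [math.log(r) if r > 0.0 else logzero for r in rewards]


rewards = [0.0, 0.5, 2.0]

# With the default reward_min of 0.0, a zero reward maps to LOGZERO.
logrewards_default = to_logrewards(clip_rewards(rewards, reward_min=0.0))

# With a small positive reward_min, the zero reward becomes log(reward_min) instead.
logrewards_clipped = to_logrewards(clip_rewards(rewards, reward_min=1e-8))
```

In this sketch, setting ``reward_min`` to a small positive value replaces the ``LOGZERO`` entry with ``log(reward_min)``, which is the kind of adjustment the ``reward_min`` description suggests for losses that do not behave well with a minimum reward of 0.0.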