"""Source code for ``rl_zoo3.callbacks``: training callbacks used by the RL Zoo."""

import os
import tempfile
import time
from copy import deepcopy
from functools import wraps
from threading import Thread
from typing import Optional, Type, Union

import optuna
from sb3_contrib import TQC
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.logger import TensorBoardOutputFormat
from stable_baselines3.common.vec_env import VecEnv

class TrialEvalCallback(EvalCallback):
    """
    Callback used for evaluating and reporting an Optuna trial.

    Every ``eval_freq`` callback calls, it runs the regular ``EvalCallback``
    evaluation. It then reports the resulting mean reward to the trial as an
    intermediate value and asks Optuna whether the trial should be pruned. When
    pruning is requested, ``is_pruned`` is set and ``False`` is returned to stop
    training.

    :param eval_env: Environment used for evaluation (forwarded to ``EvalCallback``)
    :param trial: Optuna trial receiving the intermediate values
    :param n_eval_episodes: Number of episodes per evaluation (forwarded to ``EvalCallback``)
    :param eval_freq: Evaluate every ``eval_freq`` callback calls; a non-positive
        value disables evaluation, reporting and pruning in this callback
    :param deterministic: Forwarded to ``EvalCallback``
    :param verbose: Verbosity level (forwarded to ``EvalCallback``)
    :param best_model_save_path: Forwarded to ``EvalCallback``
    :param log_path: Forwarded to ``EvalCallback``
    """

    def __init__(
        self,
        eval_env: VecEnv,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
        best_model_save_path: Optional[str] = None,
        log_path: Optional[str] = None,
    ) -> None:
        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
            best_model_save_path=best_model_save_path,
            log_path=log_path,
        )
        self.trial = trial
        # Number of evaluations done so far, used as the Optuna "step" of each report
        self.eval_idx = 0
        # Set to True when Optuna asked to prune this trial
        self.is_pruned = False

    def _on_step(self) -> bool:
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            # Runs the evaluation and updates ``self.last_mean_reward``
            super()._on_step()
            self.eval_idx += 1
            # Report the mean reward of the evaluation that just ran (not the best one),
            # indexed by evaluation count rather than by timesteps or elapsed time.
            # Without this report the pruner has no intermediate value to act on.
            self.trial.report(self.last_mean_reward, self.eval_idx)
            # Prune trial if needed
            if self.trial.should_prune():
                self.is_pruned = True
                return False
        return True
class SaveVecNormalizeCallback(BaseCallback):
    """
    Callback for saving a VecNormalize wrapper every ``save_freq`` steps.

    :param save_freq: Save every ``save_freq`` calls of the callback;
        a non-positive value disables saving.
    :param save_path: Path to the folder where ``VecNormalize`` will be saved, as ``vecnormalize.pkl``
    :param name_prefix: Common prefix to the saved ``VecNormalize``, if None (default)
        only one file will be kept (``vecnormalize.pkl``, overwritten on each save);
        otherwise one ``{name_prefix}_{num_timesteps}_steps.pkl`` file is written per save.
    :param verbose: Verbosity level; above 1, a message is printed on each save.
    """

    def __init__(self, save_freq: int, save_path: str, name_prefix: Optional[str] = None, verbose: int = 0):
        super().__init__(verbose)
        self.save_freq = save_freq
        self.save_path = save_path
        self.name_prefix = name_prefix

    def _init_callback(self) -> None:
        # Create folder if needed
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        # make mypy happy
        assert self.model is not None

        # A non-positive ``save_freq`` disables saving (``n_calls % 0`` would raise
        # ZeroDivisionError), consistent with the ``eval_freq > 0`` guard of TrialEvalCallback.
        if self.save_freq <= 0 or self.n_calls % self.save_freq != 0:
            return True

        # Fetch the wrapper once: avoids a second lookup and lets the None-check narrow the type
        vec_normalize = self.model.get_vec_normalize_env()
        if vec_normalize is None:
            # No VecNormalize wrapper: nothing to save
            return True

        if self.name_prefix is not None:
            path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps.pkl")
        else:
            path = os.path.join(self.save_path, "vecnormalize.pkl")

        vec_normalize.save(path)
        if self.verbose > 1:
            print(f"Saving VecNormalize to {path}")
        return True
class ParallelTrainCallback(BaseCallback):
    """
    Callback to explore (collect experience) and train (do gradient steps)
    at the same time using two separate threads.
    Normally used with off-policy algorithms and `train_freq=(1, "episode")`.

    A clone of the model is trained in a background thread, while the original
    model keeps collecting experience with its ``train()`` method disabled. At the
    end of a rollout, if the background training round has finished:
    - the replay buffer is copied to the clone;
    - the clone's parameters are copied back to the collecting model;
    - a new training round is started.

    TODO:
    - blocking mode: wait for the model to finish updating the policy before collecting new experience
      at the end of a rollout
    - force sync mode: stop training to update to the latest policy for collecting
      new experience

    :param gradient_steps: Number of gradient steps to do before
      sending the new policy
    :param verbose: Verbosity level
    :param sleep_time: Limit the fps in the thread collecting experience.
    """

    def __init__(self, gradient_steps: int = 100, verbose: int = 0, sleep_time: float = 0.0):
        super().__init__(verbose)
        self.batch_size = 0
        # False while the background thread is training ``self._model``
        self._model_ready = True
        # Clone of the model that is trained in the background thread
        self._model: Union[SAC, TQC]
        self.gradient_steps = gradient_steps
        # Background training thread; stays None until the first training round starts
        self.process: Optional[Thread] = None
        self.model_class: Union[Type[SAC], Type[TQC]]
        self.sleep_time = sleep_time

    def _init_callback(self) -> None:
        # make mypy happy
        assert isinstance(self.model, (SAC, TQC)), f"{self.model} is not supported for parallel training"

        # TODO: add support for other algorithms
        for model_class in [SAC, TQC]:
            if isinstance(self.model, model_class):
                self.model_class = model_class  # type: ignore[assignment]
                break

        # Clone the model through a save/load round-trip. A temporary directory is used on
        # every platform: on Windows a TemporaryFile is not accepted as an io buffer. The
        # directory is removed afterwards and does not require a pre-existing folder.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_path = os.path.join(tmp_dir, "model_tmp.zip")
            self.model.save(temp_path)
            self._model = self.model_class.load(temp_path)

        self.batch_size = self._model.batch_size

        # Disable train method
        def patch_train(function):
            @wraps(function)
            def wrapper(*args, **kwargs):
                return None

            return wrapper

        # Add logger for parallel training
        self._model.set_logger(self.model.logger)
        # The collecting model must not do gradient steps: the clone does them in the thread
        self.model.train = patch_train(self.model.train)  # type: ignore[assignment]

        # Hack: Re-add correct values at save time
        # (saving the collecting model saves the clone that is being trained instead)
        def patch_save(function):
            @wraps(function)
            def wrapper(*args, **kwargs):
                return self._model.save(*args, **kwargs)

            return wrapper

        self.model.save = patch_save(self.model.save)  # type: ignore[assignment]

    def train(self) -> None:
        """Start one background training round of ``gradient_steps`` on the clone."""
        self._model_ready = False

        self.process = Thread(target=self._train_thread, daemon=True)
        self.process.start()

    def _train_thread(self) -> None:
        self._model.train(gradient_steps=self.gradient_steps, batch_size=self.batch_size)
        self._model_ready = True

    def _on_step(self) -> bool:
        # Optionally slow down experience collection
        if self.sleep_time > 0:
            time.sleep(self.sleep_time)
        return True

    def _on_rollout_end(self) -> None:
        # Make mypy happy
        assert isinstance(self.model, (SAC, TQC))

        if self._model_ready:
            # Give the clone the latest experience, and the collector the latest weights
            self._model.replay_buffer = deepcopy(self.model.replay_buffer)
            self.model.set_parameters(deepcopy(self._model.get_parameters()))  # type: ignore[arg-type]
            # Keep the ``actor`` alias pointing at the policy's actor after the parameter update
            self.model.actor = self.model.policy.actor  # type: ignore[union-attr, attr-defined, assignment]
            if self.num_timesteps >= self._model.learning_starts:
                self.train()
            # Do not wait for the training loop to finish (non-blocking on purpose):
            # self.process.join()

    def _on_training_end(self) -> None:
        # Wait for the thread to terminate. No thread exists when no training round was
        # ever started (e.g. ``learning_starts`` was never reached).
        if self.process is not None:
            if self.verbose > 0:
                print("Waiting for training thread to terminate")
            self.process.join()
class RawStatisticsCallback(BaseCallback):
    """
    Log raw per-episode data (episodic return and episode length).

    Values are written directly to the TensorBoard output format rather than
    through the regular logger. Each one is indexed by the running sum of
    finished episode lengths, so every episode gets its own step.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        # Running sum of episode lengths, used as the step of each report
        # (so that two episodes are never reported at the same step)
        self._timesteps_counter = 0
        self._tensorboard_writer = None

    def _init_callback(self) -> None:
        assert self.logger is not None
        # Talk to the TensorBoard writer directly so the other logger outputs are not flooded
        tensorboard_formats = [
            output_format
            for output_format in self.logger.output_formats
            if isinstance(output_format, TensorBoardOutputFormat)
        ]
        if tensorboard_formats:
            # When several TensorBoard formats are configured, the last one is used
            self._tensorboard_writer = tensorboard_formats[-1]
        assert self._tensorboard_writer is not None, "You must activate tensorboard logging when using RawStatisticsCallback"

    def _on_step(self) -> bool:
        for info in self.locals["infos"]:
            # Only finished episodes carry an "episode" entry
            if "episode" not in info:
                continue
            episode = info["episode"]
            episodic_return = episode["r"]
            episodic_length = episode["l"]
            values = {
                "raw/rollouts/episodic_return": episodic_return,
                "raw/rollouts/episodic_length": episodic_length,
            }
            self._timesteps_counter += episodic_length
            self._tensorboard_writer.write(values, dict.fromkeys(values), self._timesteps_counter)
        return True