import os
import tempfile
import time
from copy import deepcopy
from functools import wraps
from threading import Thread
from typing import Optional, Type, Union
import optuna
from sb3_contrib import TQC
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.logger import TensorBoardOutputFormat
from stable_baselines3.common.vec_env import VecEnv
class TrialEvalCallback(EvalCallback):
"""
Callback used for evaluating and reporting a trial.
"""
def __init__(
self,
eval_env: VecEnv,
trial: optuna.Trial,
n_eval_episodes: int = 5,
eval_freq: int = 10000,
deterministic: bool = True,
verbose: int = 0,
best_model_save_path: Optional[str] = None,
log_path: Optional[str] = None,
) -> None:
super().__init__(
eval_env=eval_env,
n_eval_episodes=n_eval_episodes,
eval_freq=eval_freq,
deterministic=deterministic,
verbose=verbose,
best_model_save_path=best_model_save_path,
log_path=log_path,
)
self.trial = trial
self.eval_idx = 0
self.is_pruned = False
def _on_step(self) -> bool:
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
super()._on_step()
self.eval_idx += 1
# Report best or current mean reward?
# Report against num_timesteps or elapsed time?
self.trial.report(self.last_mean_reward, self.eval_idx)
# Prune trial if needed
if self.trial.should_prune():
self.is_pruned = True
return False
return True
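# Example (sketch): how TrialEvalCallback can be plugged into an Optuna objective.
# The callback reports the latest mean evaluation reward to the trial and stops
# training early when the pruner decides to prune. The env id, the sampled
# hyperparameter and the budgets below are illustrative assumptions.
def example_objective(trial: optuna.Trial) -> float:
    from stable_baselines3.common.env_util import make_vec_env

    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    train_env = make_vec_env("Pendulum-v1", n_envs=1)
    eval_env = make_vec_env("Pendulum-v1", n_envs=1)
    model = SAC("MlpPolicy", train_env, learning_rate=learning_rate, verbose=0)
    eval_callback = TrialEvalCallback(eval_env, trial, n_eval_episodes=3, eval_freq=2000)
    model.learn(total_timesteps=10_000, callback=eval_callback)
    # Propagate the pruning decision made inside the callback
    if eval_callback.is_pruned:
        raise optuna.TrialPruned()
    return eval_callback.last_mean_reward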
class SaveVecNormalizeCallback(BaseCallback):
"""
Callback for saving a VecNormalize wrapper every ``save_freq`` steps
:param save_freq: (int) Save the ``VecNormalize`` statistics every ``save_freq`` calls of the callback (in steps)
:param save_path: (str) Path to the folder where ``VecNormalize`` will be saved, as ``vecnormalize.pkl``
:param name_prefix: (str) Common prefix to the saved ``VecNormalize``; if None (default),
only one file will be kept.
"""
def __init__(self, save_freq: int, save_path: str, name_prefix: Optional[str] = None, verbose: int = 0):
super().__init__(verbose)
self.save_freq = save_freq
self.save_path = save_path
self.name_prefix = name_prefix
def _init_callback(self) -> None:
# Create folder if needed
if self.save_path is not None:
os.makedirs(self.save_path, exist_ok=True)
def _on_step(self) -> bool:
# make mypy happy
assert self.model is not None
if self.n_calls % self.save_freq == 0:
if self.name_prefix is not None:
path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps.pkl")
else:
path = os.path.join(self.save_path, "vecnormalize.pkl")
if self.model.get_vec_normalize_env() is not None:
self.model.get_vec_normalize_env().save(path) # type: ignore[union-attr]
if self.verbose > 1:
print(f"Saving VecNormalize to {path}")
return True
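# Example (sketch): saving the VecNormalize statistics that match the best model.
# SaveVecNormalizeCallback is passed as ``callback_on_new_best`` so that
# ``vecnormalize.pkl`` is written next to the best model whenever a new best
# mean reward is found. The env id, paths and frequencies are illustrative assumptions.
def example_save_vec_normalize() -> None:
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import VecNormalize

    env = VecNormalize(make_vec_env("Pendulum-v1", n_envs=1))
    eval_env = VecNormalize(make_vec_env("Pendulum-v1", n_envs=1), training=False, norm_reward=False)
    model = SAC("MlpPolicy", env, verbose=0)
    # save_freq=1: save every time the parent EvalCallback reports a new best model
    save_vec_normalize = SaveVecNormalizeCallback(save_freq=1, save_path="logs/best_model")
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path="logs/best_model",
        callback_on_new_best=save_vec_normalize,
        eval_freq=5000,
        n_eval_episodes=5,
    )
    model.learn(total_timesteps=20_000, callback=eval_callback)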
class ParallelTrainCallback(BaseCallback):
"""
Callback to explore (collect experience) and train (do gradient steps)
at the same time using two separate threads.
Normally used with off-policy algorithms and `train_freq=(1, "episode")`.
TODO:
- blocking mode: wait for the model to finish updating the policy before collecting new experience
at the end of a rollout
- force sync mode: stop training to update to the latest policy for collecting
new experience
:param gradient_steps: Number of gradient steps to do before
sending the new policy
:param verbose: Verbosity level
:param sleep_time: Time to sleep (in seconds) at every step, to limit the fps of the thread collecting experience.
"""
def __init__(self, gradient_steps: int = 100, verbose: int = 0, sleep_time: float = 0.0):
super().__init__(verbose)
self.batch_size = 0
self._model_ready = True
self._model: Union[SAC, TQC]
self.gradient_steps = gradient_steps
self.process: Thread
self.model_class: Union[Type[SAC], Type[TQC]]
self.sleep_time = sleep_time
def _init_callback(self) -> None:
temp_file = tempfile.TemporaryFile()
# On Windows, TemporaryFile is not an IO buffer,
# so we save the model to the logs/ folder instead
if os.name == "nt":
temp_file = os.path.join("logs", "model_tmp.zip") # type: ignore[arg-type,assignment]
# make mypy happy
assert isinstance(self.model, (SAC, TQC)), f"{self.model} is not supported for parallel training"
self.model.save(temp_file) # type: ignore[arg-type]
# TODO: add support for other algorithms
for model_class in [SAC, TQC]:
if isinstance(self.model, model_class):
self.model_class = model_class # type: ignore[assignment]
break
assert self.model_class is not None, f"{self.model} is not supported for parallel training"
self._model = self.model_class.load(temp_file) # type: ignore[arg-type]
self.batch_size = self._model.batch_size
# Disable the train() method of the main model: gradient steps are done in the parallel copy
def patch_train(function):
@wraps(function)
def wrapper(*args, **kwargs):
return
return wrapper
# Add logger for parallel training
self._model.set_logger(self.model.logger)
self.model.train = patch_train(self.model.train) # type: ignore[assignment]
# Hack: save the up-to-date parallel model when save() is called on the main model
def patch_save(function):
@wraps(function)
def wrapper(*args, **kwargs):
return self._model.save(*args, **kwargs)
return wrapper
self.model.save = patch_save(self.model.save) # type: ignore[assignment]
def train(self) -> None:
self._model_ready = False
self.process = Thread(target=self._train_thread, daemon=True)
self.process.start()
def _train_thread(self) -> None:
self._model.train(gradient_steps=self.gradient_steps, batch_size=self.batch_size)
self._model_ready = True
def _on_step(self) -> bool:
if self.sleep_time > 0:
time.sleep(self.sleep_time)
return True
def _on_rollout_end(self) -> None:
# Make mypy happy
assert isinstance(self.model, (SAC, TQC))
if self._model_ready:
self._model.replay_buffer = deepcopy(self.model.replay_buffer)
self.model.set_parameters(deepcopy(self._model.get_parameters())) # type: ignore[arg-type]
self.model.actor = self.model.policy.actor # type: ignore[union-attr, attr-defined, assignment]
if self.num_timesteps >= self._model.learning_starts:
self.train()
# Do not wait for the training loop to finish
# self.process.join()
def _on_training_end(self) -> None:
# Wait for the thread to terminate
if self.process is not None:
if self.verbose > 0:
print("Waiting for training thread to terminate")
self.process.join()
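# Example (sketch): using ParallelTrainCallback so that gradient updates run in a
# background thread while the main thread keeps collecting experience. Because the
# callback replaces ``model.train`` with a no-op, the gradient steps are effectively
# those of the callback. The env id and timestep budget are illustrative assumptions.
def example_parallel_train() -> None:
    from stable_baselines3.common.env_util import make_vec_env

    env = make_vec_env("Pendulum-v1", n_envs=1)
    model = SAC("MlpPolicy", env, train_freq=(1, "episode"), verbose=1)
    callback = ParallelTrainCallback(gradient_steps=100)
    model.learn(total_timesteps=20_000, callback=callback)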
class RawStatisticsCallback(BaseCallback):
"""
Callback used for logging raw episode data (return and episode length).
"""
def __init__(self, verbose: int = 0):
super().__init__(verbose)
# Custom counter to report stats
# (and avoid reporting multiple values for the same step)
self._timesteps_counter = 0
self._tensorboard_writer = None
def _init_callback(self) -> None:
assert self.logger is not None
# Retrieve the TensorBoard writer so we do not flood the standard logger output
for out_format in self.logger.output_formats:
if isinstance(out_format, TensorBoardOutputFormat):
self._tensorboard_writer = out_format
assert self._tensorboard_writer is not None, "You must activate tensorboard logging when using RawStatisticsCallback"
def _on_step(self) -> bool:
for info in self.locals["infos"]:
if "episode" in info:
logger_dict = {
"raw/rollouts/episodic_return": info["episode"]["r"],
"raw/rollouts/episodic_length": info["episode"]["l"],
}
exclude_dict = {key: None for key in logger_dict.keys()}
self._timesteps_counter += info["episode"]["l"]
self._tensorboard_writer.write(logger_dict, exclude_dict, self._timesteps_counter)
return True
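# Example (sketch): RawStatisticsCallback requires TensorBoard logging to be enabled,
# otherwise its ``_init_callback`` assertion fails. The env id and log directory are
# illustrative assumptions.
def example_raw_statistics() -> None:
    from stable_baselines3.common.env_util import make_vec_env

    env = make_vec_env("Pendulum-v1", n_envs=1)
    # ``tensorboard_log`` activates the TensorBoardOutputFormat the callback writes to
    model = SAC("MlpPolicy", env, tensorboard_log="logs/tb", verbose=0)
    model.learn(total_timesteps=5_000, callback=RawStatisticsCallback())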