Source code for rl_zoo3.utils

import argparse
import glob
import importlib
import os
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import gymnasium as gym
import stable_baselines3 as sb3  # noqa: F401
import torch as th  # noqa: F401
import yaml
from gymnasium import spaces
from huggingface_hub import HfApi
from huggingface_sb3 import EnvironmentName, ModelName
from sb3_contrib import ARS, QRDQN, TQC, TRPO, RecurrentPPO
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike  # noqa: F401
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv, VecFrameStack, VecNormalize

# For custom activation fn
from torch import nn as nn

ALGOS: Dict[str, Type[BaseAlgorithm]] = {
    "a2c": A2C,
    "ddpg": DDPG,
    "dqn": DQN,
    "ppo": PPO,
    "sac": SAC,
    "td3": TD3,
    # SB3 Contrib
    "ars": ARS,
    "qrdqn": QRDQN,
    "tqc": TQC,
    "trpo": TRPO,
    "ppo_lstm": RecurrentPPO,
}
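

# Illustrative usage sketch: ALGOS maps the lowercase algorithm names used on the
# command line to their SB3 / SB3-Contrib classes. The policy and env id below are
# example values only.
def _example_algos_lookup() -> BaseAlgorithm:
    algo_class = ALGOS["ppo"]  # -> stable_baselines3.PPO
    return algo_class("MlpPolicy", "CartPole-v1")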


def flatten_dict_observations(env: gym.Env) -> gym.Env:
    assert isinstance(env.observation_space, spaces.Dict)
    return gym.wrappers.FlattenObservation(env)


def get_wrapper_class(hyperparams: Dict[str, Any], key: str = "env_wrapper") -> Optional[Callable[[gym.Env], gym.Env]]:
    """
    Get one or more Gym environment wrapper class specified as a hyper parameter
    "env_wrapper".
    Works also for VecEnvWrapper with the key "vec_env_wrapper".

    e.g.
    env_wrapper: gym_minigrid.wrappers.FlatObsWrapper

    for multiple, specify a list:

    env_wrapper:
        - rl_zoo3.wrappers.PlotActionWrapper
        - rl_zoo3.wrappers.TimeFeatureWrapper

    :param hyperparams:
    :return: maybe a callable to wrap the environment
        with one or multiple gym.Wrapper
    """

    def get_module_name(wrapper_name):
        return ".".join(wrapper_name.split(".")[:-1])

    def get_class_name(wrapper_name):
        return wrapper_name.split(".")[-1]

    if key in hyperparams.keys():
        wrapper_name = hyperparams.get(key)

        if wrapper_name is None:
            return None

        if not isinstance(wrapper_name, list):
            wrapper_names = [wrapper_name]
        else:
            wrapper_names = wrapper_name

        wrapper_classes = []
        wrapper_kwargs = []
        # Handle multiple wrappers
        for wrapper_name in wrapper_names:
            # Handle keyword arguments
            if isinstance(wrapper_name, dict):
                assert len(wrapper_name) == 1, (
                    "You have an error in the formatting "
                    f"of your YAML file near {wrapper_name}. "
                    "You should check the indentation."
                )
                wrapper_dict = wrapper_name
                wrapper_name = next(iter(wrapper_dict.keys()))
                kwargs = wrapper_dict[wrapper_name]
            else:
                kwargs = {}
            wrapper_module = importlib.import_module(get_module_name(wrapper_name))
            wrapper_class = getattr(wrapper_module, get_class_name(wrapper_name))
            wrapper_classes.append(wrapper_class)
            wrapper_kwargs.append(kwargs)

        def wrap_env(env: gym.Env) -> gym.Env:
            """
            :param env:
            :return:
            """
            for wrapper_class, kwargs in zip(wrapper_classes, wrapper_kwargs):
                env = wrapper_class(env, **kwargs)
            return env

        return wrap_env
    else:
        return None
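

# Illustrative usage sketch for get_wrapper_class: the dict below mimics hyperparameters
# parsed from a YAML config. The chosen wrappers, kwargs and env id are example values only.
def _example_get_wrapper_class_usage() -> gym.Env:
    hyperparams = {
        "env_wrapper": [
            # plain wrapper, no keyword arguments
            "gymnasium.wrappers.FlattenObservation",
            # wrapper with keyword arguments (the dict form produced by YAML)
            {"gymnasium.wrappers.TimeLimit": {"max_episode_steps": 200}},
        ]
    }
    wrap_fn = get_wrapper_class(hyperparams)
    assert wrap_fn is not None
    return wrap_fn(gym.make("CartPole-v1"))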


def get_class_by_name(name: str) -> Type:
    """
    Imports and returns a class given the name, e.g. passing
    'stable_baselines3.common.callbacks.CheckpointCallback' returns the
    CheckpointCallback class.

    :param name:
    :return:
    """

    def get_module_name(name: str) -> str:
        return ".".join(name.split(".")[:-1])

    def get_class_name(name: str) -> str:
        return name.split(".")[-1]

    module = importlib.import_module(get_module_name(name))
    return getattr(module, get_class_name(name))


def get_callback_list(hyperparams: Dict[str, Any]) -> List[BaseCallback]:
    """
    Get one or more Callback class specified as a hyper-parameter
    "callback".
    e.g.
    callback: stable_baselines3.common.callbacks.CheckpointCallback

    for multiple, specify a list:

    callback:
        - rl_zoo3.callbacks.PlotActionWrapper
        - stable_baselines3.common.callbacks.CheckpointCallback

    :param hyperparams:
    :return:
    """

    callbacks: List[BaseCallback] = []

    if "callback" in hyperparams.keys():
        callback_name = hyperparams.get("callback")

        if callback_name is None:
            return callbacks

        if not isinstance(callback_name, list):
            callback_names = [callback_name]
        else:
            callback_names = callback_name

        # Handle multiple callbacks
        for callback_name in callback_names:
            # Handle keyword arguments
            if isinstance(callback_name, dict):
                assert len(callback_name) == 1, (
                    "You have an error in the formatting "
                    f"of your YAML file near {callback_name}. "
                    "You should check the indentation."
                )
                callback_dict = callback_name
                callback_name = next(iter(callback_dict.keys()))
                kwargs = callback_dict[callback_name]
            else:
                kwargs = {}

            callback_class = get_class_by_name(callback_name)
            callbacks.append(callback_class(**kwargs))

    return callbacks
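

# Illustrative usage sketch for get_callback_list: "callback" entries follow the same
# YAML conventions as "env_wrapper". The CheckpointCallback arguments are example values only.
def _example_get_callback_list_usage() -> List[BaseCallback]:
    hyperparams = {
        "callback": [
            {
                "stable_baselines3.common.callbacks.CheckpointCallback": {
                    "save_freq": 10_000,
                    "save_path": "./checkpoints/",
                }
            }
        ]
    }
    return get_callback_list(hyperparams)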


def create_test_env(
    env_id: str,
    n_envs: int = 1,
    stats_path: Optional[str] = None,
    seed: int = 0,
    log_dir: Optional[str] = None,
    should_render: bool = True,
    hyperparams: Optional[Dict[str, Any]] = None,
    env_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
    """
    Create environment for testing a trained agent

    :param env_id:
    :param n_envs: number of processes
    :param stats_path: path to folder containing saved running averages
    :param seed: Seed for random number generator
    :param log_dir: Where to log rewards
    :param should_render: For Pybullet env, display the GUI
    :param hyperparams: Additional hyperparams (ex: n_stack)
    :param env_kwargs: Optional keyword argument to pass to the env constructor
    :return:
    """
    # Create the environment and wrap it if necessary
    assert hyperparams is not None
    env_wrapper = get_wrapper_class(hyperparams)

    hyperparams = {} if hyperparams is None else hyperparams

    if "env_wrapper" in hyperparams.keys():
        del hyperparams["env_wrapper"]

    vec_env_kwargs: Dict[str, Any] = {}
    # Avoid potential shared memory issue
    vec_env_cls = SubprocVecEnv if n_envs > 1 else DummyVecEnv

    # Fix for gym 0.26, to keep old behavior
    env_kwargs = env_kwargs or {}
    env_kwargs = deepcopy(env_kwargs)
    if "render_mode" not in env_kwargs and should_render:
        env_kwargs.update(render_mode="human")

    spec = gym.spec(env_id)

    # Define make_env here, so it works with subprocesses
    # when the registry was modified with `--gym-packages`
    # See https://github.com/HumanCompatibleAI/imitation/pull/160
    def make_env(**kwargs) -> gym.Env:
        return spec.make(**kwargs)

    env = make_vec_env(
        make_env,
        n_envs=n_envs,
        monitor_dir=log_dir,
        seed=seed,
        wrapper_class=env_wrapper,
        env_kwargs=env_kwargs,
        vec_env_cls=vec_env_cls,  # type: ignore[arg-type]
        vec_env_kwargs=vec_env_kwargs,
    )

    if "vec_env_wrapper" in hyperparams.keys():
        vec_env_wrapper = get_wrapper_class(hyperparams, "vec_env_wrapper")
        assert vec_env_wrapper is not None
        env = vec_env_wrapper(env)  # type: ignore[assignment, arg-type]
        del hyperparams["vec_env_wrapper"]

    # Load saved stats for normalizing input and rewards
    # And optionally stack frames
    if stats_path is not None:
        if hyperparams["normalize"]:
            print("Loading running average")
            print(f"with params: {hyperparams['normalize_kwargs']}")
            path_ = os.path.join(stats_path, "vecnormalize.pkl")
            if os.path.exists(path_):
                env = VecNormalize.load(path_, env)
                # Deactivate training and reward normalization
                env.training = False
                env.norm_reward = False
            else:
                raise ValueError(f"VecNormalize stats {path_} not found")

        n_stack = hyperparams.get("frame_stack", 0)
        if n_stack > 0:
            print(f"Stacking {n_stack} frames")
            env = VecFrameStack(env, n_stack)
    return env
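

# Illustrative usage sketch for create_test_env: a single, non-rendered evaluation env
# without saved VecNormalize stats. The env id and settings are example values only;
# `hyperparams` would normally come from get_saved_hyperparams().
def _example_create_test_env_usage() -> VecEnv:
    return create_test_env(
        "CartPole-v1",
        n_envs=1,
        stats_path=None,  # no VecNormalize statistics to load
        seed=0,
        should_render=False,
        hyperparams={},
    )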


def linear_schedule(initial_value: Union[float, str]) -> Callable[[float], float]:
    """
    Linear learning rate schedule.

    :param initial_value: (float or str)
    :return: (function)
    """
    # Force conversion to float
    initial_value_ = float(initial_value)

    def func(progress_remaining: float) -> float:
        """
        Progress will decrease from 1 (beginning) to 0

        :param progress_remaining: (float)
        :return: (float)
        """
        return progress_remaining * initial_value_

    return func
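

# Illustrative usage sketch for linear_schedule: SB3 calls the returned function with the
# remaining training progress (1.0 at the start, 0.0 at the end), so the value decays
# linearly to 0. The initial value 3e-4 is an example only; the schedule can be passed
# directly, e.g. as the `learning_rate` argument of an SB3 algorithm.
def _example_linear_schedule_usage() -> None:
    lr_schedule = linear_schedule(3e-4)
    assert lr_schedule(1.0) == 3e-4  # start of training
    assert lr_schedule(0.5) == 1.5e-4  # halfway through
    assert lr_schedule(0.0) == 0.0  # end of training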


def get_trained_models(log_folder: str) -> Dict[str, Tuple[str, str]]:
    """
    :param log_folder: Root log folder
    :return: Dict representing the trained agents
    """
    trained_models = {}
    for algo in os.listdir(log_folder):
        if not os.path.isdir(os.path.join(log_folder, algo)):
            continue
        for model_folder in os.listdir(os.path.join(log_folder, algo)):
            args_files = glob.glob(os.path.join(log_folder, algo, model_folder, "*/args.yml"))
            if len(args_files) != 1:
                continue  # we expect only one sub-folder with an args.yml file
            with open(args_files[0]) as fh:
                env_id = yaml.load(fh, Loader=yaml.UnsafeLoader)["env"]

            model_name = ModelName(algo, EnvironmentName(env_id))
            trained_models[model_name] = (algo, env_id)
    return trained_models


def get_hf_trained_models(organization: str = "sb3", check_filename: bool = False) -> Dict[str, Tuple[str, str]]:
    """
    Get pretrained models, available on the Huggingface hub
    for a given organization.

    :param organization: Huggingface organization
        Stable-Baselines (SB3) one is the default.
    :param check_filename: Perform additional check per model
        to be sure they match the RL Zoo convention.
        (this will slow down things as it requires one API call per model)
    :return: Dict representing the trained agents
    """
    api = HfApi()
    models = api.list_models(author=organization, cardData=True)

    trained_models = {}
    for model in models:
        # Try to extract algorithm and environment id from the model card
        try:
            env_id = model.cardData["model-index"][0]["results"][0]["dataset"]["name"]
            algo = model.cardData["model-index"][0]["name"].lower()
            # RecurrentPPO alias is "ppo_lstm" in the rl zoo
            if algo == "recurrentppo":
                algo = "ppo_lstm"
        except (KeyError, IndexError):
            print(f"Skipping {model.modelId}")
            continue  # skip model if the env id or algo name could not be found

        env_name = EnvironmentName(env_id)
        model_name = ModelName(algo, env_name)

        # check if there is a model file in the repo
        if check_filename and not any(f.rfilename == model_name.filename for f in api.model_info(model.modelId).siblings):
            continue  # skip model if the repo contains no properly named model file

        trained_models[model_name] = (algo, env_id)
    return trained_models


def get_latest_run_id(log_path: str, env_name: EnvironmentName) -> int:
    """
    Returns the latest run number for the given log name and log path,
    by finding the greatest number in the directories.

    :param log_path: path to log folder
    :param env_name:
    :return: latest run number
    """
    max_run_id = 0
    for path in glob.glob(os.path.join(log_path, env_name + "_[0-9]*")):
        run_id = path.split("_")[-1]
        path_without_run_id = path[: -len(run_id) - 1]
        if path_without_run_id.endswith(env_name) and run_id.isdigit() and int(run_id) > max_run_id:
            max_run_id = int(run_id)
    return max_run_id
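

# Illustrative usage sketch for get_latest_run_id: with run folders named
# "<env>_<run_id>" (e.g. CartPole-v1_1, CartPole-v1_2), the highest run id is returned,
# and 0 if no matching folder exists. The temporary folder layout is an example only.
def _example_get_latest_run_id_usage() -> int:
    import tempfile

    log_path = tempfile.mkdtemp()
    for run_id in (1, 2):
        os.makedirs(os.path.join(log_path, f"CartPole-v1_{run_id}"))
    return get_latest_run_id(log_path, EnvironmentName("CartPole-v1"))  # -> 2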


def get_saved_hyperparams(
    stats_path: str,
    norm_reward: bool = False,
    test_mode: bool = False,
) -> Tuple[Dict[str, Any], Optional[str]]:
    """
    Retrieve saved hyperparameters given a path.
    Return empty dict and None if the path is not valid.

    :param stats_path:
    :param norm_reward:
    :param test_mode:
    :return:
    """
    hyperparams: Dict[str, Any] = {}
    if not os.path.isdir(stats_path):
        return hyperparams, None
    else:
        config_file = os.path.join(stats_path, "config.yml")
        if os.path.isfile(config_file):
            # Load saved hyperparameters
            with open(os.path.join(stats_path, "config.yml")) as f:
                hyperparams = yaml.load(f, Loader=yaml.UnsafeLoader)
            hyperparams["normalize"] = hyperparams.get("normalize", False)
        else:
            obs_rms_path = os.path.join(stats_path, "obs_rms.pkl")
            hyperparams["normalize"] = os.path.isfile(obs_rms_path)

        # Load normalization params
        if hyperparams["normalize"]:
            if isinstance(hyperparams["normalize"], str):
                normalize_kwargs = eval(hyperparams["normalize"])
                if test_mode:
                    normalize_kwargs["norm_reward"] = norm_reward
            else:
                normalize_kwargs = {"norm_obs": hyperparams["normalize"], "norm_reward": norm_reward}
            hyperparams["normalize_kwargs"] = normalize_kwargs
    return hyperparams, stats_path
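

# Illustrative usage sketch for get_saved_hyperparams: as documented, an invalid path
# falls back to an empty dict and None. The path below is an example value only.
def _example_get_saved_hyperparams_usage() -> None:
    hyperparams, stats_path = get_saved_hyperparams("path/that/does/not/exist")
    assert hyperparams == {} and stats_path is None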


class StoreDict(argparse.Action):
    """
    Custom argparse action for storing dict.

    In: args1:0.0 args2:"dict(a=1)"
    Out: {'args1': 0.0, 'args2': dict(a=1)}
    """

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        self._nargs = nargs
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        arg_dict = {}
        for arguments in values:
            key = arguments.split(":")[0]
            value = ":".join(arguments.split(":")[1:])
            # Evaluate the string as python code
            arg_dict[key] = eval(value)
        setattr(namespace, self.dest, arg_dict)
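

# Illustrative usage sketch for StoreDict: repeated "key:value" tokens become one dict,
# with each value evaluated as Python. The option name and values are example choices
# mirroring flags such as --env-kwargs.
def _example_store_dict_usage() -> Dict[str, Any]:
    parser = argparse.ArgumentParser()
    parser.add_argument("--env-kwargs", type=str, nargs="+", action=StoreDict)
    args = parser.parse_args(["--env-kwargs", "g:9.81", "render_mode:'rgb_array'"])
    return args.env_kwargs  # {'g': 9.81, 'render_mode': 'rgb_array'}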


def get_model_path(
    exp_id: int,
    folder: str,
    algo: str,
    env_name: EnvironmentName,
    load_best: bool = False,
    load_checkpoint: Optional[str] = None,
    load_last_checkpoint: bool = False,
) -> Tuple[str, str, str]:
    if exp_id == 0:
        exp_id = get_latest_run_id(os.path.join(folder, algo), env_name)
        print(f"Loading latest experiment, id={exp_id}")
    # Sanity checks
    if exp_id > 0:
        log_path = os.path.join(folder, algo, f"{env_name}_{exp_id}")
    else:
        log_path = os.path.join(folder, algo)

    assert os.path.isdir(log_path), f"The {log_path} folder was not found"

    model_name = ModelName(algo, env_name)

    if load_best:
        model_path = os.path.join(log_path, "best_model.zip")
        name_prefix = f"best-model-{model_name}"
    elif load_checkpoint is not None:
        model_path = os.path.join(log_path, f"rl_model_{load_checkpoint}_steps.zip")
        name_prefix = f"checkpoint-{load_checkpoint}-{model_name}"
    elif load_last_checkpoint:
        checkpoints = glob.glob(os.path.join(log_path, "rl_model_*_steps.zip"))
        if len(checkpoints) == 0:
            raise ValueError(f"No checkpoint found for {algo} on {env_name}, path: {log_path}")

        def step_count(checkpoint_path: str) -> int:
            # paths follow the pattern "rl_model_*_steps.zip", so we count from the back
            # to ignore any other "_" in the path
            return int(checkpoint_path.split("_")[-2])

        checkpoints = sorted(checkpoints, key=step_count)
        model_path = checkpoints[-1]
        name_prefix = f"checkpoint-{step_count(model_path)}-{model_name}"
    else:
        # Default: load latest model
        model_path = os.path.join(log_path, f"{env_name}.zip")
        name_prefix = f"final-model-{model_name}"

    found = os.path.isfile(model_path)
    if not found:
        raise ValueError(f"No model found for {algo} on {env_name}, path: {model_path}")

    return name_prefix, model_path, log_path
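

# Illustrative usage sketch for get_model_path: assuming a log folder laid out as
# logs/ppo/CartPole-v1_1/ containing best_model.zip, the call below would resolve and
# return a display prefix together with the model and log paths. All values are examples.
#
#     name_prefix, model_path, log_path = get_model_path(
#         exp_id=1,
#         folder="logs",
#         algo="ppo",
#         env_name=EnvironmentName("CartPole-v1"),
#         load_best=True,
#     )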