Utils
- class rl_zoo3.utils.StoreDict(option_strings, dest, nargs=None, **kwargs)
Custom argparse action that stores key:value arguments in a dict; each value is evaluated as a Python expression.
In: args1:0.0 args2:"dict(a=1)"
Out: {'args1': 0.0, 'args2': {'a': 1}}
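The In/Out example can be reproduced with a small self-contained stand-in. `StoreDictSketch` is a hypothetical name, not the rl_zoo3 class. It assumes each argument is split at its first `:` and that the value is evaluated with `eval` (so it must only see trusted input). The `--env-kwargs` flag is chosen for illustration only.

```python
import argparse
from typing import Any


class StoreDictSketch(argparse.Action):
    """Hypothetical stand-in for StoreDict: turns "key:value" strings into a dict."""

    def __call__(self, parser, namespace, values, option_string=None):
        arg_dict: dict[str, Any] = {}
        for argument in values:
            # Split only at the first ":" so that values may themselves contain colons.
            key, _, value = argument.partition(":")
            # Assumption: the value is evaluated as a Python expression (trusted input only).
            arg_dict[key] = eval(value)
        setattr(namespace, self.dest, arg_dict)


parser = argparse.ArgumentParser()
# nargs="+" lets a single flag accept several key:value pairs.
parser.add_argument("--env-kwargs", type=str, nargs="+", action=StoreDictSketch, default={})
args = parser.parse_args(["--env-kwargs", "args1:0.0", "args2:dict(a=1)"])
print(args.env_kwargs)  # {'args1': 0.0, 'args2': {'a': 1}}
```

On a real command line the shell strips the quotes in `args2:"dict(a=1)"`, so the action receives the same strings as the list passed to `parse_args` here.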
- rl_zoo3.utils.create_test_env(env_id, n_envs=1, stats_path=None, seed=0, log_dir=None, should_render=True, hyperparams=None, env_kwargs=None, vec_env_cls=None, vec_env_kwargs=None)
Create an environment for testing a trained agent.
- Parameters:
env_id (str) – ID of the environment to create
n_envs (int) – number of processes
stats_path (str | None) – path to the folder containing the saved running averages
seed (int) – Seed for the random number generator
log_dir (str | None) – Where to log rewards
should_render (bool) – For PyBullet envs, display the GUI
hyperparams (dict[str, Any] | None) – Additional hyperparameters (e.g. n_stack)
env_kwargs (dict[str, Any] | None) – Optional keyword arguments to pass to the env constructor
vec_env_cls (type[VecEnv] | None) – VecEnv class constructor.
vec_env_kwargs (dict[str, Any] | None) – Keyword arguments to pass to the VecEnv class constructor.
- Returns:
vectorized environment for testing the agent
- Return type:
VecEnv
- rl_zoo3.utils.get_callback_list(hyperparams)
Get one or more callback classes specified by the "callback" hyperparameter and instantiate them, e.g.
    callback: stable_baselines3.common.callbacks.CheckpointCallback
For multiple callbacks, specify a list:
    callback:
        - rl_zoo3.callbacks.PlotActionWrapper
        - stable_baselines3.common.callbacks.CheckpointCallback
- Parameters:
hyperparams (dict[str, Any])
- Returns:
list of the instantiated callbacks
- Return type:
list[BaseCallback]
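A minimal sketch of the lookup described above, assuming the "callback" value is either one dotted class path or a list of them, and that each class can be constructed without arguments. `get_callback_list_sketch` and `load_class` are hypothetical names. Stdlib classes stand in for callback classes so the example runs without Stable-Baselines3 installed.

```python
import importlib
from typing import Any


def load_class(name: str) -> type:
    # "package.module.ClassName" -> import "package.module", then fetch "ClassName".
    module_name, _, class_name = name.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


def get_callback_list_sketch(hyperparams: dict[str, Any]) -> list[Any]:
    """Hypothetical re-implementation of the documented behaviour."""
    spec = hyperparams.get("callback")
    if spec is None:
        return []
    # Accept either a single dotted name or a list of dotted names.
    names = spec if isinstance(spec, list) else [spec]
    # Assumption: every class is instantiated without arguments.
    return [load_class(name)() for name in names]


# Stdlib classes stand in for BaseCallback subclasses here.
callbacks = get_callback_list_sketch({"callback": ["collections.OrderedDict", "collections.Counter"]})
print([type(cb).__name__ for cb in callbacks])  # ['OrderedDict', 'Counter']
```

Normalizing a single string into a one-element list keeps the single-callback and multi-callback cases on the same code path.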
- rl_zoo3.utils.get_class_by_name(name)
Import and return a class given its fully qualified name, e.g. passing 'stable_baselines3.common.callbacks.CheckpointCallback' returns the CheckpointCallback class.
- Parameters:
name (str)
- Returns:
the imported class
- Return type:
type
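Resolving a dotted path to a class can be sketched with `importlib`: import everything before the last dot as a module, then read the last component as an attribute. `get_class_by_name_sketch` is a hypothetical name, and a stdlib class is used so the example does not require Stable-Baselines3.

```python
import importlib


def get_class_by_name_sketch(name: str) -> type:
    """Hypothetical equivalent: import the module part of a dotted path and
    return the attribute named by its last component."""
    module_name, _, class_name = name.rpartition(".")
    # Raises ModuleNotFoundError if the module part cannot be imported.
    module = importlib.import_module(module_name)
    # Raises AttributeError if the module has no attribute with that name.
    return getattr(module, class_name)


cls = get_class_by_name_sketch("collections.OrderedDict")
print(cls.__name__)  # OrderedDict

try:
    get_class_by_name_sketch("collections.DoesNotExist")
except AttributeError as err:
    print("lookup failed:", err)
```

Because the returned object is the class itself, callers still decide how to instantiate it and with which arguments.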
- rl_zoo3.utils.get_hf_trained_models(organization='sb3', check_filename=False)
Get the pretrained models available on the Hugging Face Hub for a given organization.
- Parameters:
organization (str) – Hugging Face organization; the Stable-Baselines3 one (sb3) is the default.
check_filename (bool) – Perform an additional check per model to make sure its files follow the RL Zoo convention (this slows things down, as it requires one API call per model).
- Returns:
Dict representing the trained agents, mapping each model name to an (algo, env_id) tuple
- Return type:
dict[str, tuple[str, str]]
- rl_zoo3.utils.get_latest_run_id(log_path, env_name)
Return the latest run number for the given environment name and log path, i.e. the greatest run number found among the existing run directories.
- Parameters:
log_path (str) – path to log folder
env_name (EnvironmentName)
- Returns:
latest run number
- Return type:
int
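A minimal sketch of the "greatest number" search, assuming run folders are named `<env_name>_<N>` directly under `log_path` and that 0 is returned when no run exists yet. `get_latest_run_id_sketch` is a hypothetical name, and the folder names are made up for the demo.

```python
import os
import re
import tempfile


def get_latest_run_id_sketch(log_path: str, env_name: str) -> int:
    """Return the greatest N among sub-directories named "<env_name>_<N>",
    or 0 when none exist (hypothetical sketch, assumed naming convention)."""
    # re.escape keeps characters such as "-" or "." in env IDs literal.
    pattern = re.compile(re.escape(env_name) + r"_(\d+)")
    max_run_id = 0
    for entry in os.listdir(log_path):
        match = pattern.fullmatch(entry)
        if match and os.path.isdir(os.path.join(log_path, entry)):
            # Compare as integers so that run 10 beats run 3.
            max_run_id = max(max_run_id, int(match.group(1)))
    return max_run_id


with tempfile.TemporaryDirectory() as log_path:
    for folder in ["CartPole-v1_1", "CartPole-v1_3", "CartPole-v1_10", "Pendulum-v1_7"]:
        os.makedirs(os.path.join(log_path, folder))
    latest = get_latest_run_id_sketch(log_path, "CartPole-v1")

print(latest)  # 10
```

A caller creating a new run would typically use `latest + 1` as the next folder suffix; matching the full folder name keeps runs of other environments (here `Pendulum-v1_7`) out of the count.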
- rl_zoo3.utils.get_saved_hyperparams(stats_path, norm_reward=False, test_mode=False)
Retrieve the saved hyperparameters from the given path. Return an empty dict and None if the path is not valid.
- Parameters:
stats_path (str)
norm_reward (bool)
test_mode (bool)
- Returns:
tuple of the saved hyperparameters and the stats path
- Return type:
tuple[dict[str, Any], str | None]
- rl_zoo3.utils.get_trained_models(log_folder)
Get the trained models available in a local log folder.
- Parameters:
log_folder (str) – Root log folder
- Returns:
Dict representing the trained agents, mapping each model name to an (algo, env_id) tuple
- Return type:
dict[str, tuple[str, str]]
- rl_zoo3.utils.get_wrapper_class(hyperparams, key='env_wrapper')
Get one or more Gym environment wrapper classes specified by the "env_wrapper" hyperparameter, e.g.
    env_wrapper: gym_minigrid.wrappers.FlatObsWrapper
For multiple wrappers, specify a list:
    env_wrapper:
        - rl_zoo3.wrappers.PlotActionWrapper
        - rl_zoo3.wrappers.TimeFeatureWrapper
This also works for VecEnvWrapper with the key "vec_env_wrapper".
- Parameters:
hyperparams (dict[str, Any])
key (str)
- Returns:
a callable that wraps the environment with one or more gym.Wrapper classes, or None if no wrapper is specified
- Return type:
Callable[[Env], Env] | None
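The returned callable can be sketched as a closure that applies the listed wrappers one after another. `get_wrapper_class_sketch` and `_load` are hypothetical names; the sketch assumes wrappers are applied in list order (so the last entry ends up outermost) and does not handle wrapper keyword arguments. Stdlib mapping classes stand in for `gym.Wrapper` subclasses so the example runs without Gym installed.

```python
import importlib
from typing import Any, Callable, Optional


def _load(name: str) -> type:
    # Import "package.module" and fetch "ClassName" from a dotted path.
    module_name, _, class_name = name.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


def get_wrapper_class_sketch(
    hyperparams: dict[str, Any], key: str = "env_wrapper"
) -> Optional[Callable[[Any], Any]]:
    """Hypothetical sketch: None when no wrapper is configured, otherwise a
    callable that applies every configured wrapper in list order."""
    spec = hyperparams.get(key)
    if spec is None:
        return None
    names = spec if isinstance(spec, list) else [spec]
    wrapper_classes = [_load(name) for name in names]

    def wrap_env(env: Any) -> Any:
        # The first entry wraps the raw env; each later entry wraps the previous result.
        for wrapper_class in wrapper_classes:
            env = wrapper_class(env)
        return env

    return wrap_env


wrap = get_wrapper_class_sketch({"env_wrapper": ["collections.ChainMap", "types.MappingProxyType"]})
wrapped = wrap({"a": 1})
print(type(wrapped).__name__, wrapped["a"])  # mappingproxy 1
```

Resolving the classes once, outside `wrap_env`, means the returned callable can be reused for every environment instance without repeating the imports.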