Utils

class rl_zoo3.utils.StoreDict(option_strings, dest, nargs=None, **kwargs)[source]

Custom argparse action for storing dict.

In: args1:0.0 args2:"dict(a=1)" Out: {'args1': 0.0, 'args2': dict(a=1)}
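
A minimal usage sketch, assuming the action is attached to a hypothetical --env-kwargs flag (any flag using this action behaves the same way):

    import argparse

    from rl_zoo3.utils import StoreDict

    parser = argparse.ArgumentParser()
    # Each "key:value" string is split on the first colon and the value is
    # evaluated, then all pairs are collected into a single dict
    parser.add_argument("--env-kwargs", type=str, nargs="+", action=StoreDict)
    args = parser.parse_args(["--env-kwargs", "args1:0.0", "args2:dict(a=1)"])
    print(args.env_kwargs)  # {'args1': 0.0, 'args2': {'a': 1}}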

rl_zoo3.utils.create_test_env(env_id, n_envs=1, stats_path=None, seed=0, log_dir=None, should_render=True, hyperparams=None, env_kwargs=None)[source]

Create an environment for testing a trained agent.

Parameters:
  • env_id (str) – environment ID

  • n_envs (int) – number of processes

  • stats_path (str | None) – path to the folder containing the saved running averages

  • seed (int) – Seed for random number generator

  • log_dir (str | None) – Where to log rewards

  • should_render (bool) – For Pybullet env, display the GUI

  • hyperparams (Dict[str, Any] | None) – Additional hyperparams (ex: n_stack)

  • env_kwargs (Dict[str, Any] | None) – Optional keyword arguments to pass to the env constructor

Returns:

The wrapped environment

Return type:

VecEnv
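
A usage sketch; the environment ID and stats path below are placeholders, and stats_path should point at the folder holding the saved running averages:

    from rl_zoo3.utils import create_test_env

    env = create_test_env(
        "CartPole-v1",
        n_envs=1,
        stats_path="logs/ppo/CartPole-v1_1/CartPole-v1",  # placeholder path
        seed=0,
        should_render=False,
    )
    obs = env.reset()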

rl_zoo3.utils.get_callback_list(hyperparams)[source]

Get one or more callback classes specified by the "callback" hyper-parameter, e.g. callback: stable_baselines3.common.callbacks.CheckpointCallback

For multiple callbacks, specify a list:

callback:
  - rl_zoo3.callbacks.PlotActionWrapper
  - stable_baselines3.common.callbacks.CheckpointCallback

Parameters:

hyperparams (Dict[str, Any]) –

Returns:

The list of instantiated callbacks

Return type:

List[BaseCallback]
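
The same specification can be passed as a plain Python dict; a sketch, with illustrative CheckpointCallback keyword arguments:

    from rl_zoo3.utils import get_callback_list

    hyperparams = {
        "callback": [
            {
                "stable_baselines3.common.callbacks.CheckpointCallback": dict(
                    save_freq=10_000,  # illustrative values
                    save_path="logs/checkpoints",
                )
            }
        ]
    }
    callbacks = get_callback_list(hyperparams)  # list of instantiated callbacks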

rl_zoo3.utils.get_class_by_name(name)[source]

Imports and returns a class given the name, e.g. passing ‘stable_baselines3.common.callbacks.CheckpointCallback’ returns the CheckpointCallback class.

Parameters:

name (str) –

Returns:

The imported class

Return type:

Type
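
For example:

    from rl_zoo3.utils import get_class_by_name

    callback_class = get_class_by_name(
        "stable_baselines3.common.callbacks.CheckpointCallback"
    )
    # Roughly equivalent to:
    # from stable_baselines3.common.callbacks import CheckpointCallback

Such a helper is typically a thin wrapper around importlib; a sketch, not necessarily the library's exact code:

    import importlib

    def get_class_by_name_sketch(name: str) -> type:
        # Split "module.path.ClassName" into module path and class name,
        # import the module, then look up the class attribute
        module_name, class_name = name.rsplit(".", maxsplit=1)
        return getattr(importlib.import_module(module_name), class_name)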

rl_zoo3.utils.get_hf_trained_models(organization='sb3', check_filename=False)[source]

Get the pretrained models available on the Hugging Face hub for a given organization.

Parameters:
  • organization (str) – Hugging Face organization; the Stable-Baselines3 (SB3) one is the default.

  • check_filename (bool) – Perform an additional check per model to make sure it matches the RL Zoo convention (this slows things down, as it requires one API call per model).

Returns:

Dict representing the trained agents

Return type:

Dict[str, Tuple[str, str]]
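
A usage sketch; the returned entries shown in the comment are illustrative:

    from rl_zoo3.utils import get_hf_trained_models

    # With check_filename=True this makes one API call per model, so it is slower
    trained_models = get_hf_trained_models(organization="sb3")
    # e.g. {"ppo-CartPole-v1": ("ppo", "CartPole-v1"), ...}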

rl_zoo3.utils.get_latest_run_id(log_path, env_name)[source]

Returns the latest run number for the given log name and log path, by finding the greatest number in the directories.

Parameters:
  • log_path (str) – path to log folder

  • env_name (EnvironmentName) –

Returns:

latest run number

Return type:

int
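
For example, with the <env_name>_<run_id> directory layout the RL Zoo uses (the paths below are hypothetical):

    from rl_zoo3.utils import get_latest_run_id

    # If logs/ppo/ contains CartPole-v1_1 and CartPole-v1_2, this returns 2
    run_id = get_latest_run_id("logs/ppo", "CartPole-v1")
    save_path = f"logs/ppo/CartPole-v1_{run_id + 1}"  # next run directory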

rl_zoo3.utils.get_saved_hyperparams(stats_path, norm_reward=False, test_mode=False)[source]

Retrieve the saved hyperparameters given a path. Returns an empty dict and None if the path is not valid.

Parameters:
  • stats_path (str) –

  • norm_reward (bool) –

  • test_mode (bool) –

Returns:

Return type:

Tuple[Dict[str, Any], str | None]
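
A usage sketch; the path is a placeholder and should point at the folder containing the saved config and normalization statistics:

    from rl_zoo3.utils import get_saved_hyperparams

    hyperparams, stats_path = get_saved_hyperparams(
        "logs/ppo/CartPole-v1_1/CartPole-v1",  # placeholder path
        norm_reward=False,
        test_mode=True,
    )
    # stats_path is None when the given path is not valid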

rl_zoo3.utils.get_trained_models(log_folder)[source]

Get the trained agents saved locally in the given log folder.

Parameters:

log_folder (str) – Root log folder

Returns:

Dict representing the trained agents

Return type:

Dict[str, Tuple[str, str]]
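
For example (the folder name and returned entries are illustrative):

    from rl_zoo3.utils import get_trained_models

    trained_models = get_trained_models("logs")
    # e.g. {"ppo-CartPole-v1": ("ppo", "CartPole-v1"), ...}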

rl_zoo3.utils.get_wrapper_class(hyperparams, key='env_wrapper')[source]

Get one or more Gym environment wrapper classes specified by the "env_wrapper" hyper-parameter. Also works for VecEnvWrapper with the key "vec_env_wrapper".

e.g. env_wrapper: gym_minigrid.wrappers.FlatObsWrapper

For multiple wrappers, specify a list:

env_wrapper:
  - rl_zoo3.wrappers.PlotActionWrapper
  - rl_zoo3.wrappers.TimeFeatureWrapper

Parameters:
  • hyperparams (Dict[str, Any]) –

  • key (str) –

Returns:

a callable to wrap the environment with one or more gym.Wrapper, or None if no wrapper is specified

Return type:

Callable[[Env], Env] | None
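
A usage sketch, reusing the wrapper named above (the environment ID is a placeholder):

    import gym

    from rl_zoo3.utils import get_wrapper_class

    hyperparams = {"env_wrapper": ["rl_zoo3.wrappers.TimeFeatureWrapper"]}
    maybe_wrap = get_wrapper_class(hyperparams)  # Callable[[Env], Env] or None

    env = gym.make("Pendulum-v1")
    if maybe_wrap is not None:
        env = maybe_wrap(env)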

rl_zoo3.utils.linear_schedule(initial_value)[source]

Linear learning rate schedule.

Parameters:

initial_value (float | str) – initial learning rate (a string will be converted to float)

Returns:

A schedule that computes the current learning rate from the remaining progress

Return type:

Callable[[float], float]
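
The returned schedule is called with the remaining progress, which decreases from 1 (start of training) to 0 (end). A small sketch:

    from rl_zoo3.utils import linear_schedule

    schedule = linear_schedule(3e-4)
    schedule(1.0)  # 3e-4 at the start of training
    schedule(0.5)  # 1.5e-4 halfway through
    schedule(0.0)  # 0.0 at the end

    # Typically passed as a learning rate, e.g.:
    # model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(3e-4))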