Utils
- class rl_zoo3.utils.StoreDict(option_strings, dest, nargs=None, **kwargs)[source]
Custom argparse action for storing a dict.
In: args1:0.0 args2:"dict(a=1)" Out: {'args1': 0.0, 'args2': dict(a=1)}
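The behaviour above can be sketched with a hypothetical re-implementation (`StoreDictSketch` is an illustrative name, not rl_zoo3's exact code):

```python
import argparse


class StoreDictSketch(argparse.Action):
    """Hypothetical re-implementation of an argparse action that parses
    "key:value" tokens into a single dict (illustration only)."""

    def __call__(self, parser, namespace, values, option_string=None):
        arg_dict = {}
        for token in values:
            # Split at the first ":" only, so values may contain colons.
            key, _, raw_value = token.partition(":")
            # eval turns "0.0" into a float and "dict(a=1)" into a dict.
            arg_dict[key] = eval(raw_value)
        setattr(namespace, self.dest, arg_dict)


parser = argparse.ArgumentParser()
parser.add_argument("--env-kwargs", type=str, nargs="+", action=StoreDictSketch)
args = parser.parse_args(["--env-kwargs", "args1:0.0", "args2:dict(a=1)"])
print(args.env_kwargs)  # {'args1': 0.0, 'args2': {'a': 1}}
```

Using `eval` keeps the sketch short but, as always, should only be applied to trusted command-line input.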
- rl_zoo3.utils.create_test_env(env_id, n_envs=1, stats_path=None, seed=0, log_dir=None, should_render=True, hyperparams=None, env_kwargs=None)[source]
Create an environment for testing a trained agent.
- Parameters:
  - env_id (str) –
  - n_envs (int) – number of processes
  - stats_path (Optional[str]) – path to folder containing saved running averages
  - seed (int) – Seed for random number generator
  - log_dir (Optional[str]) – Where to log rewards
  - should_render (bool) – For PyBullet envs, display the GUI
  - hyperparams (Optional[Dict[str, Any]]) – Additional hyperparams (ex: n_stack)
  - env_kwargs (Optional[Dict[str, Any]]) – Optional keyword arguments to pass to the env constructor
- Return type:
  VecEnv
- rl_zoo3.utils.get_callback_list(hyperparams)[source]
Get one or more Callback classes specified by the hyperparameter "callback", e.g. callback: stable_baselines3.common.callbacks.CheckpointCallback
For multiple callbacks, specify a list:
- callback:
  - rl_zoo3.callbacks.PlotActionWrapper
  - stable_baselines3.common.callbacks.CheckpointCallback
- Parameters:
  - hyperparams (Dict[str, Any]) –
- Return type:
  List[BaseCallback]
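The string-to-class resolution can be sketched as follows; `get_callback_list_sketch` is a hypothetical simplification that ignores per-callback keyword arguments, demonstrated here with a stdlib class standing in for a real BaseCallback subclass:

```python
import importlib


def get_callback_list_sketch(hyperparams):
    """Hypothetical simplification: resolve the "callback" entry (a single
    class path or a list of them) into a list of classes."""
    callback_names = hyperparams.get("callback", [])
    if not isinstance(callback_names, list):
        callback_names = [callback_names]
    callbacks = []
    for name in callback_names:
        # Import the module, then fetch the class attribute by name.
        module_name, class_name = name.rsplit(".", 1)
        callbacks.append(getattr(importlib.import_module(module_name), class_name))
    return callbacks


# Demonstrated with a stdlib class instead of a real BaseCallback subclass:
print(get_callback_list_sketch({"callback": "collections.OrderedDict"}))
```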
- rl_zoo3.utils.get_class_by_name(name)[source]
Imports and returns a class given its fully qualified name, e.g. passing 'stable_baselines3.common.callbacks.CheckpointCallback' returns the CheckpointCallback class.
- Parameters:
  - name (str) –
- Return type:
  Type
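A minimal sketch of such dynamic importing, using only the standard library (`get_class_by_name_sketch` is a hypothetical name, not the library's exact implementation):

```python
import importlib


def get_class_by_name_sketch(name: str) -> type:
    """Split 'module.path.ClassName' at the last dot, import the module,
    and fetch the class attribute (hypothetical re-implementation)."""
    module_name, class_name = name.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)


# Works for any importable class, e.g. one from the stdlib:
print(get_class_by_name_sketch("collections.OrderedDict"))
```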
- rl_zoo3.utils.get_hf_trained_models(organization='sb3', check_filename=False)[source]
Get the pretrained models available on the Hugging Face hub for a given organization.
- Parameters:
  - organization (str) – Hugging Face organization; the Stable-Baselines3 (SB3) one is the default
  - check_filename (bool) – Perform an additional check per model to be sure it matches the RL Zoo convention (this slows things down, as it requires one API call per model)
- Return type:
  Dict[str, Tuple[str, str]]
- Returns:
  Dict representing the trained agents
- rl_zoo3.utils.get_latest_run_id(log_path, env_name)[source]
Returns the latest run number for the given log name and log path, by finding the greatest number in the directories.
- Parameters:
  - log_path (str) – path to log folder
  - env_name (EnvironmentName) –
- Return type:
  int
- Returns:
  latest run number
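Assuming run folders are named `<env_name>_<N>` (an assumption about the on-disk layout), the lookup can be sketched as:

```python
import os
import re
import tempfile


def get_latest_run_id_sketch(log_path: str, env_name: str) -> int:
    """Hypothetical re-implementation: scan log_path for folders named
    '<env_name>_<N>' and return the greatest N (0 if none found)."""
    max_run_id = 0
    pattern = re.compile(rf"^{re.escape(env_name)}_(\d+)$")
    if os.path.isdir(log_path):
        for entry in os.listdir(log_path):
            match = pattern.match(entry)
            if match:
                max_run_id = max(max_run_id, int(match.group(1)))
    return max_run_id


# Demo with a temporary directory mimicking the assumed layout:
with tempfile.TemporaryDirectory() as tmp:
    for run in ("CartPole-v1_1", "CartPole-v1_3", "CartPole-v1_2"):
        os.mkdir(os.path.join(tmp, run))
    print(get_latest_run_id_sketch(tmp, "CartPole-v1"))  # 3
```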
- rl_zoo3.utils.get_saved_hyperparams(stats_path, norm_reward=False, test_mode=False)[source]
Retrieve saved hyperparameters given a path. Returns an empty dict and None if the path is not valid.
- Parameters:
  - stats_path (str) –
  - norm_reward (bool) –
  - test_mode (bool) –
- Return type:
  Tuple[Dict[str, Any], Optional[str]]
- rl_zoo3.utils.get_trained_models(log_folder)[source]
- Parameters:
  - log_folder (str) – Root log folder
- Return type:
  Dict[str, Tuple[str, str]]
- Returns:
  Dict representing the trained agents
- rl_zoo3.utils.get_wrapper_class(hyperparams, key='env_wrapper')[source]
Get one or more Gym environment wrapper classes specified by the hyperparameter "env_wrapper". Also works for VecEnvWrapper with the key "vec_env_wrapper".
e.g. env_wrapper: gym_minigrid.wrappers.FlatObsWrapper
For multiple wrappers, specify a list:
- env_wrapper:
  - rl_zoo3.wrappers.PlotActionWrapper
  - rl_zoo3.wrappers.TimeFeatureWrapper
- Parameters:
  - hyperparams (Dict[str, Any]) –
- Return type:
  Optional[Callable[[Env], Env]]
- Returns:
  maybe a callable to wrap the environment with one or multiple gym.Wrapper
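The returned callable applies each wrapper in order. This can be sketched with a hypothetical simplification that supports only plain class-path strings (no per-wrapper keyword arguments), demonstrated with a stdlib class standing in for a gym.Wrapper:

```python
import importlib


def get_wrapper_class_sketch(hyperparams, key="env_wrapper"):
    """Hypothetical simplification: resolve the wrapper spec (one class path
    or a list of them) into a callable that wraps an env with each class in
    order, or None when the key is absent."""
    wrapper_names = hyperparams.get(key)
    if wrapper_names is None:
        return None
    if not isinstance(wrapper_names, list):
        wrapper_names = [wrapper_names]
    wrapper_classes = []
    for name in wrapper_names:
        module_name, class_name = name.rsplit(".", 1)
        wrapper_classes.append(getattr(importlib.import_module(module_name), class_name))

    def wrap_env(env):
        # Apply wrappers innermost-first, in the order they were listed.
        for wrapper_class in wrapper_classes:
            env = wrapper_class(env)
        return env

    return wrap_env


# Demonstrated with a stdlib class standing in for a gym.Wrapper:
wrap = get_wrapper_class_sketch({"env_wrapper": "collections.OrderedDict"})
print(type(wrap({"a": 1})))  # <class 'collections.OrderedDict'>
```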