Utils
- class rl_zoo3.utils.StoreDict(option_strings, dest, nargs=None, **kwargs)
Custom argparse action for storing a dict.
In: args1:0.0 args2:"dict(a=1)"  Out: {'args1': 0.0, 'args2': dict(a=1)}
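The In/Out example above amounts to splitting each token at its first colon and evaluating the right-hand side as a Python expression. The following is a minimal sketch of that behaviour, not the library source; the class name `StoreDictSketch` and the `--env-kwargs` option are illustrative assumptions.

```python
import argparse
from typing import Any, Dict, Optional


class StoreDictSketch(argparse.Action):
    """Hypothetical re-implementation: turn 'key:value' tokens into a dict."""

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Any,
        option_string: Optional[str] = None,
    ) -> None:
        arg_dict: Dict[str, Any] = {}
        for token in values:  # nargs="+" delivers a list of strings
            # Split on the first ":" only, so the value itself may contain ":"
            key, _, value = token.partition(":")
            # eval() runs arbitrary code: only use this with trusted input
            arg_dict[key] = eval(value)
        setattr(namespace, self.dest, arg_dict)


parser = argparse.ArgumentParser()
parser.add_argument("--env-kwargs", type=str, nargs="+", action=StoreDictSketch, default={})
args = parser.parse_args(["--env-kwargs", "args1:0.0", "args2:dict(a=1)"])
print(args.env_kwargs)  # {'args1': 0.0, 'args2': {'a': 1}}
```

On a POSIX shell command line, the quotes in `args2:"dict(a=1)"` only protect the parentheses from the shell; argparse receives the token `args2:dict(a=1)`.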
- rl_zoo3.utils.create_test_env(env_id, n_envs=1, stats_path=None, seed=0, log_dir=None, should_render=True, hyperparams=None, env_kwargs=None)
Create an environment for testing a trained agent.
- Parameters:
  env_id (str) – ID of the environment
  n_envs (int) – number of processes
  stats_path (Optional[str]) – path to the folder containing the saved running averages
  seed (int) – seed for the random number generator
  log_dir (Optional[str]) – where to log rewards
  should_render (bool) – for PyBullet envs, display the GUI
  hyperparams (Optional[Dict[str, Any]]) – additional hyperparameters (e.g. n_stack)
  env_kwargs (Optional[Dict[str, Any]]) – optional keyword arguments to pass to the env constructor
- Return type:
  VecEnv
- Returns:
  the vectorized environment for testing
- rl_zoo3.utils.get_callback_list(hyperparams)
Get one or more callback classes specified as the hyperparameter "callback", e.g.:
  callback: stable_baselines3.common.callbacks.CheckpointCallback
For multiple callbacks, specify a list:
  callback:
    - rl_zoo3.callbacks.SaveVecNormalizeCallback
    - stable_baselines3.common.callbacks.CheckpointCallback
- Parameters:
  hyperparams (Dict[str, Any]) – hyperparameters, possibly containing a "callback" entry
- Return type:
  List[BaseCallback]
- Returns:
  the list of instantiated callbacks
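The "callback" format described above (a single dotted class name, or a list of them) can be resolved with `importlib`. The sketch below is a hypothetical helper, not rl_zoo3's code: `instantiate_from_hyperparams` is an illustrative name, treating a one-key dict `{"pkg.Class": {...}}` as a class plus constructor keyword arguments is an assumption, and standard-library classes stand in for callback classes so it runs without Stable-Baselines3.

```python
import importlib
from typing import Any, Dict, List


def instantiate_from_hyperparams(hyperparams: Dict[str, Any], key: str = "callback") -> List[Any]:
    """Hypothetical helper: instantiate every class listed under `key`."""
    spec = hyperparams.get(key)
    if spec is None:
        return []  # nothing requested
    entries = spec if isinstance(spec, list) else [spec]  # single name -> one-item list
    objects: List[Any] = []
    for entry in entries:
        if isinstance(entry, dict):
            # Assumed form: {"pkg.module.Class": {constructor kwargs}}
            assert len(entry) == 1, "one class per list item"
            name, kwargs = next(iter(entry.items()))
        else:
            name, kwargs = entry, {}
        # "pkg.module.Class" -> import "pkg.module", then fetch "Class"
        module_name, _, class_name = name.rpartition(".")
        cls = getattr(importlib.import_module(module_name), class_name)
        objects.append(cls(**kwargs))
    return objects


objs = instantiate_from_hyperparams(
    {"callback": ["collections.OrderedDict", {"fractions.Fraction": {"numerator": 3, "denominator": 4}}]}
)
print([type(o).__name__ for o in objs], objs[1])  # ['OrderedDict', 'Fraction'] 3/4
```

Normalizing a single entry into a one-item list first keeps a single loop for both the scalar and list forms of the hyperparameter.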
- rl_zoo3.utils.get_class_by_name(name)
Imports and returns a class given its name, e.g. passing 'stable_baselines3.common.callbacks.CheckpointCallback' returns the CheckpointCallback class.
- Parameters:
  name (str) – fully qualified class name
- Return type:
  Type
- Returns:
  the class object
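The lookup described above comes down to importing everything before the last dot as a module and reading the last component as an attribute. This is a minimal sketch under that assumption; `class_by_name_sketch` is a hypothetical name, and `collections.Counter` / `fractions.Fraction` merely stand in for classes such as CheckpointCallback.

```python
import importlib


def class_by_name_sketch(name: str) -> type:
    """Hypothetical sketch: resolve 'pkg.module.Class' to the Class object."""
    module_name, _, class_name = name.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a fully qualified name such as 'pkg.module.Class', got {name!r}")
    module = importlib.import_module(module_name)  # raises ImportError for unknown modules
    return getattr(module, class_name)  # raises AttributeError for unknown classes


counter_cls = class_by_name_sketch("collections.Counter")
print(counter_cls("abca")["a"])  # 2
```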
- rl_zoo3.utils.get_hf_trained_models(organization='sb3', check_filename=False)
Get pretrained models available on the Hugging Face Hub for a given organization.
- Parameters:
  organization (str) – Hugging Face organization; the Stable-Baselines3 one ("sb3") is the default.
  check_filename (bool) – perform an additional check per model to make sure it matches the RL Zoo naming convention (this slows things down, as it requires one API call per model)
- Return type:
  Dict[str, Tuple[str, str]]
- Returns:
  Dict representing the trained agents
- rl_zoo3.utils.get_latest_run_id(log_path, env_name)
Returns the latest run number for the given environment name and log path, by finding the greatest number among the run directories.
- Parameters:
  log_path (str) – path to the log folder
  env_name (EnvironmentName) – name of the environment whose runs are searched
- Return type:
  int
- Returns:
  latest run number
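"Finding the greatest number in the directories" can be sketched as a scan for sub-directories named `<env_name>_<N>`. That naming convention is an assumption of this sketch, as is the name `latest_run_id_sketch`; it is not the library source.

```python
import os
import re
import tempfile


def latest_run_id_sketch(log_path: str, env_name: str) -> int:
    """Hypothetical sketch: largest N among sub-directories '<env_name>_<N>', else 0."""
    if not os.path.isdir(log_path):
        return 0  # no log folder yet, so no runs
    pattern = re.compile(re.escape(env_name) + r"_(\d+)")
    max_run_id = 0
    for entry in os.listdir(log_path):
        match = pattern.fullmatch(entry)
        if match and os.path.isdir(os.path.join(log_path, entry)):
            # Compare as integers so that run 10 beats run 2
            max_run_id = max(max_run_id, int(match.group(1)))
    return max_run_id


with tempfile.TemporaryDirectory() as log_path:
    for run in ("CartPole-v1_1", "CartPole-v1_2", "CartPole-v1_10", "Pendulum-v1_7"):
        os.mkdir(os.path.join(log_path, run))
    result = latest_run_id_sketch(log_path, "CartPole-v1")
    print(result)  # 10
```

Parsing the suffix as an integer matters: a plain string comparison would rank "CartPole-v1_2" above "CartPole-v1_10". A caller wanting a fresh folder could then use N + 1.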
- rl_zoo3.utils.get_saved_hyperparams(stats_path, norm_reward=False, test_mode=False)
Retrieve saved hyperparameters given a path. Returns an empty dict and None if the path is not valid.
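- Parameters:
  stats_path (str) – path to the folder containing the saved hyperparameters
  norm_reward (bool)
  test_mode (bool)
- Return type:
  Tuple[Dict[str, Any], Optional[str]]
- Returns:
  the saved hyperparameters and the stats path (an empty dict and None if the path is not valid)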
- rl_zoo3.utils.get_trained_models(log_folder)
- Parameters:
  log_folder (str) – root log folder
- Return type:
  Dict[str, Tuple[str, str]]
- Returns:
  Dict representing the trained agents
- rl_zoo3.utils.get_wrapper_class(hyperparams, key='env_wrapper')
Get one or more Gym environment wrapper classes specified as the hyperparameter "env_wrapper". This also works for VecEnvWrapper, using the key "vec_env_wrapper".
e.g.:
  env_wrapper: gym_minigrid.wrappers.FlatObsWrapper
For multiple wrappers, specify a list:
  env_wrapper:
    - rl_zoo3.wrappers.PlotActionWrapper
    - rl_zoo3.wrappers.TimeFeatureWrapper
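- Parameters:
  hyperparams (Dict[str, Any]) – hyperparameters, possibly containing the wrapper entry
  key (str) – hyperparameter key to read ("env_wrapper" or "vec_env_wrapper")
- Return type:
  Optional[Callable[[Env], Env]]
- Returns:
  a callable that wraps the environment with one or more gym.Wrapper, or None if no wrapper is specified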
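Returning a single callable for several wrappers means chaining them. The sketch below assumes they are applied in list order (the first listed wrapper ends up innermost) and that the result is None when the key is absent. To run without Gym, `Env`/`Wrapper` are minimal stand-ins and classes are passed directly rather than as dotted names; `wrapper_from_hyperparams`, `FirstWrapper` and `SecondWrapper` are hypothetical names.

```python
from typing import Any, Callable, Dict, Optional


class Env:
    """Stand-in for gym.Env so the sketch runs without Gym installed."""


class Wrapper(Env):
    """Stand-in for gym.Wrapper: keeps a reference to the wrapped env."""

    def __init__(self, env: Env) -> None:
        self.env = env


class FirstWrapper(Wrapper):
    pass


class SecondWrapper(Wrapper):
    pass


def wrapper_from_hyperparams(
    hyperparams: Dict[str, Any], key: str = "env_wrapper"
) -> Optional[Callable[[Env], Env]]:
    """Hypothetical helper: chain the wrapper classes listed under `key`."""
    spec = hyperparams.get(key)
    if spec is None:
        return None  # no wrapper requested
    wrapper_classes = spec if isinstance(spec, list) else [spec]

    def wrap_env(env: Env) -> Env:
        # Apply in list order: the first class ends up innermost
        for wrapper_class in wrapper_classes:
            env = wrapper_class(env)
        return env

    return wrap_env


wrap = wrapper_from_hyperparams({"env_wrapper": [FirstWrapper, SecondWrapper]})
env = wrap(Env())
print(type(env).__name__, type(env.env).__name__)  # SecondWrapper FirstWrapper
```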