il_representations.il package
Module contents
The il package contains re-implementations of imitation learning (IL) algorithms used in our joint training experiments.
Submodules
il_representations.il.bc module
il_representations.il.bc_support module
il_representations.il.disc_rew_nets module
il_representations.il.gail_pol_save module
- class il_representations.il.gail_pol_save.GAILSavePolicyCallback(ppo_algo, save_every_n_steps, save_dir, *, save_template='policy_{timesteps:08d}_steps.pt')
Bases: object
This callback can be passed to AdversarialTrainer.train() to save a policy snapshot every save_every_n_steps time steps.
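To illustrate the "snapshot every N timesteps" idea, here is a minimal sketch, not the module's actual implementation. It assumes the wrapped algorithm exposes a `num_timesteps` counter and a `policy.save(path)` method (SB3 algorithms and policies provide both). The class name `SavePolicyEveryNSteps` is hypothetical, and the sketch accepts and ignores whatever arguments the trainer passes to the callback, because it reads progress from the algorithm itself.

```python
import os
from typing import Any, Optional


class SavePolicyEveryNSteps:
    """Hypothetical sketch: snapshot ``algo.policy`` every N environment timesteps.

    Assumes ``algo`` exposes ``num_timesteps`` and ``policy.save(path)``,
    as SB3 algorithms and policies do.
    """

    def __init__(
        self,
        algo: Any,
        save_every_n_steps: int,
        save_dir: str,
        *,
        save_template: str = "policy_{timesteps:08d}_steps.pt",
    ) -> None:
        if save_every_n_steps <= 0:
            raise ValueError("save_every_n_steps must be positive")
        self.algo = algo
        self.save_every_n_steps = save_every_n_steps
        self.save_dir = save_dir
        self.save_template = save_template
        self.last_save_timesteps = 0  # timestep count at the previous snapshot
        os.makedirs(save_dir, exist_ok=True)

    def __call__(self, *args: Any, **kwargs: Any) -> Optional[str]:
        # Arguments from the trainer (e.g. a round index) are ignored; progress
        # is measured by the algorithm's own timestep counter.
        timesteps = self.algo.num_timesteps
        if timesteps - self.last_save_timesteps < self.save_every_n_steps:
            return None  # not enough new timesteps since the last snapshot
        path = os.path.join(
            self.save_dir, self.save_template.format(timesteps=timesteps)
        )
        self.algo.policy.save(path)
        self.last_save_timesteps = timesteps
        return path
```

With `save_every_n_steps=2048`, a call at 2048 timesteps writes `policy_00002048_steps.pt`, a call at 3000 writes nothing, and a call at 4096 writes the next snapshot. Measuring the gap since the last save (rather than requiring an exact multiple of N) means a snapshot is still taken when the trainer's callback cadence does not line up with N.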
il_representations.il.score_logging module
SB3 score-logging callback for MAGICAL. It should be safe to include when using any environment: if the eval_score key is not present in infos, the callback adds no log entries.
- class il_representations.il.score_logging.SB3ScoreLoggingCallback(*args: Any, **kwargs: Any)
Bases: stable_baselines3.common.callbacks.BaseCallback
Callback for SB3 RL algorithms that extracts ‘eval_score’ from the step info dict (if present) and includes it in the logs. Useful for MAGICAL, which reports end-of-trajectory performance using eval_score.
Tested with PPO, but may work with other algorithms too.
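The core of such a callback is pulling an optional key out of a batch of step info dicts and doing nothing when it is absent. The sketch below isolates that logic as a plain function (the name `mean_eval_score` is hypothetical, and averaging over parallel environments is an assumption, not necessarily what the module does). The commented class shows one way it could be wired into an SB3 `BaseCallback`, assuming the rollout loop exposes `infos` through `self.locals`; the log key `eval_score_mean` is also hypothetical.

```python
from statistics import mean
from typing import Any, Dict, Iterable, Optional


def mean_eval_score(
    infos: Iterable[Dict[str, Any]], key: str = "eval_score"
) -> Optional[float]:
    """Average ``key`` over the info dicts that contain it; None if none do."""
    scores = [float(info[key]) for info in infos if key in info]
    return mean(scores) if scores else None


# Hypothetical wiring inside an SB3 callback (requires stable_baselines3):
#
# class ScoreLogger(BaseCallback):
#     def _on_step(self) -> bool:
#         score = mean_eval_score(self.locals["infos"])
#         if score is not None:  # skip logging when no env reported a score
#             self.logger.record("eval_score_mean", score)
#         return True  # returning False would stop training
```

Returning `None` instead of a default such as `0.0` keeps environments that never report `eval_score` from polluting the logs with fake values.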
il_representations.il.utils module
Utilities that are helpful for several pieces of IL code (e.g. in both il_train.py and joint_training.py).
- il_representations.il.utils.add_infos(data_iter)
Add a dummy ‘infos’ value to each dict in a data stream.
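A streaming transform like this is naturally a generator that yields one modified dict per input dict. The sketch below is an illustration under stated assumptions: the function name `add_dummy_infos` is hypothetical, the placeholder value (an empty dict) is an assumption since the docstring does not say what the dummy is, and it chooses to copy each item and keep any existing `infos` rather than overwrite it.

```python
from typing import Any, Dict, Iterable, Iterator


def add_dummy_infos(
    data_iter: Iterable[Dict[str, Any]],
) -> Iterator[Dict[str, Any]]:
    """Hypothetical sketch: lazily attach a placeholder 'infos' to each dict."""
    for item in data_iter:
        new_item = dict(item)  # shallow copy so the caller's dict is unchanged
        # Placeholder value is an assumption; an existing 'infos' is kept.
        new_item.setdefault("infos", {})
        yield new_item
```

Because it is a generator, it can wrap an arbitrarily long data stream without materializing it in memory.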
- il_representations.il.utils.streaming_extract_keys(*keys_to_keep)
Filter a generator of dicts to keep only the specified keys.
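Since the signature takes only the keys, one plausible reading is a factory that returns a filter to apply to a stream later. The sketch below follows that reading; the names `make_key_filter` and `filter_stream` are hypothetical, and raising `KeyError` on a missing key (rather than silently skipping it) is a design choice of this sketch.

```python
from typing import Any, Callable, Dict, Iterable, Iterator

Stream = Iterable[Dict[str, Any]]


def make_key_filter(*keys_to_keep: str) -> Callable[[Stream], Iterator[Dict[str, Any]]]:
    """Hypothetical sketch: build a filter that keeps only ``keys_to_keep``."""

    def filter_stream(data_iter: Stream) -> Iterator[Dict[str, Any]]:
        for item in data_iter:
            # A missing key raises KeyError, surfacing malformed data early.
            yield {k: item[k] for k in keys_to_keep}

    return filter_stream
```

For example, `make_key_filter("obs", "acts")` applied to `[{"obs": 1, "acts": 2, "rews": 3}]` yields `{"obs": 1, "acts": 2}`, dropping the reward entry.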