il_representations.il package

Module contents

The il package contains re-implementations of imitation learning (IL) algorithms used in our joint training experiments.

Submodules

il_representations.il.bc module

il_representations.il.bc_support module

il_representations.il.disc_rew_nets module

il_representations.il.gail_pol_save module

class il_representations.il.gail_pol_save.GAILSavePolicyCallback(ppo_algo, save_every_n_steps, save_dir, *, save_template='policy_{timesteps:08d}_steps.pt')

Bases: object

This callback can be passed to AdversarialTrainer.train() to save a snapshot of the policy every save_every_n_steps timesteps.
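The docs do not say how the callback decides when to save. The sketch below shows one way such a callback could work. It assumes that the callback is called once per training round with the round number (as the imitation library's AdversarialTrainer.train() does with its callback argument), and that the wrapped algorithm exposes SB3's num_timesteps attribute and policy.save() method. PolicySnapshotCallback is a hypothetical name, and this is not the module's actual implementation.

```python
import os
from typing import Any, List


class PolicySnapshotCallback:
    """Hypothetical sketch: save a policy snapshot every N environment timesteps.

    Assumes `algo` has SB3-style `num_timesteps` and `policy.save(path)`.
    """

    def __init__(
        self,
        algo: Any,
        save_every_n_steps: int,
        save_dir: str,
        *,
        save_template: str = "policy_{timesteps:08d}_steps.pt",
    ) -> None:
        if save_every_n_steps <= 0:
            raise ValueError("save_every_n_steps must be positive")
        self.algo = algo
        self.save_every_n_steps = save_every_n_steps
        self.save_dir = save_dir
        self.save_template = save_template
        self.last_save_step = 0
        self.saved_paths: List[str] = []

    def __call__(self, round_num: int) -> None:
        # Assumed to be invoked once per training round; round_num is unused here.
        timesteps = self.algo.num_timesteps
        if timesteps - self.last_save_step < self.save_every_n_steps:
            return  # not enough new timesteps since the last snapshot
        os.makedirs(self.save_dir, exist_ok=True)
        path = os.path.join(
            self.save_dir, self.save_template.format(timesteps=timesteps)
        )
        self.algo.policy.save(path)
        self.saved_paths.append(path)
        self.last_save_step = timesteps
```

Because the check only runs at round boundaries, each snapshot is taken at the first round that ends at least save_every_n_steps timesteps after the previous one, so saved step counts need not be exact multiples of the interval.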

il_representations.il.score_logging module

SB3 score-logging callback for MAGICAL. It should also be safe to use with any other environment: if the eval_score key is not present in infos, no log entries are added.

class il_representations.il.score_logging.SB3ScoreLoggingCallback(*args: Any, **kwargs: Any)

Bases: stable_baselines3.common.callbacks.BaseCallback

Callback for SB3 RL algorithms that extracts ‘eval_score’ from the step info dict (if present) and adds it to the logs. This is useful for MAGICAL, which reports end-of-trajectory performance via eval_score.

Tested with PPO; it may also work with other algorithms.
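The core of the logging logic can be sketched as a pure function over the per-environment info dicts. This assumes the callback reads those dicts from self.locals["infos"] in its _on_step() hook (SB3's on-policy rollout loop exposes the infos returned by env.step() there) and records a single value per step. mean_eval_score is a hypothetical helper, and averaging across parallel environments is an assumption about how the real callback combines scores.

```python
from statistics import mean
from typing import Any, Dict, Optional, Sequence


def mean_eval_score(
    infos: Sequence[Dict[str, Any]], key: str = "eval_score"
) -> Optional[float]:
    """Average `key` over the info dicts that contain it.

    Returns None when no dict has the key, so a callback can skip logging
    entirely for environments that never report it.
    """
    scores = [float(info[key]) for info in infos if key in info]
    if not scores:
        return None
    return mean(scores)


# Inside a BaseCallback subclass, _on_step() might then do (assumed wiring):
#     score = mean_eval_score(self.locals.get("infos", []))
#     if score is not None:
#         self.logger.record("eval_score", score)
#     return True
```

Returning None when no info dict carries the key mirrors the documented behavior: environments other than MAGICAL simply produce no log entries.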

il_representations.il.utils module

Utilities shared by several pieces of IL code (e.g. both il_train.py and joint_training.py).

il_representations.il.utils.add_infos(data_iter)

Add a dummy ‘infos’ value to each dict in a data stream.
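A minimal sketch of this kind of stream transform, assuming each item is a dict and that an empty dict is an acceptable dummy value (the real placeholder is not documented). add_dummy_infos is a hypothetical name, and leaving an existing ‘infos’ value in place is a choice of the sketch.

```python
from typing import Any, Dict, Iterable, Iterator


def add_dummy_infos(data_iter: Iterable[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
    """Lazily yield a copy of each dict with a placeholder 'infos' entry."""
    for item in data_iter:
        # Copy so the caller's dicts are not mutated. The empty-dict placeholder
        # is an assumption; an existing 'infos' value is left untouched.
        out = dict(item)
        out.setdefault("infos", {})
        yield out
```

Being a generator, it processes one item at a time, so it can wrap arbitrarily long data streams without materializing them.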

il_representations.il.utils.streaming_extract_keys(*keys_to_keep)

Filter a generator of dicts to keep only the specified keys.
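Because streaming_extract_keys takes only the keys, it presumably returns a function that is then applied to the data stream. The sketch below follows that assumption; make_key_filter is a hypothetical name, and raising KeyError when an item lacks a requested key is a design choice of the sketch, not documented behavior.

```python
from typing import Any, Callable, Dict, Iterable, Iterator

Stream = Iterable[Dict[str, Any]]


def make_key_filter(*keys_to_keep: str) -> Callable[[Stream], Iterator[Dict[str, Any]]]:
    """Return a generator function that keeps only `keys_to_keep` in each dict."""

    def filter_stream(data_iter: Stream) -> Iterator[Dict[str, Any]]:
        for item in data_iter:
            # Raises KeyError if an item lacks a requested key (sketch's choice).
            yield {key: item[key] for key in keys_to_keep}

    return filter_stream
```

Returning a function lets the filter be built once from the key list and then reused across several streams.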