diff --git a/conda_env.yml b/conda_env.yml index e79c891..c4993ca 100644 --- a/conda_env.yml +++ b/conda_env.yml @@ -1,4 +1,4 @@ -name: pytorch_sac_ae +name: pytorch_sac_ae2 channels: - defaults dependencies: diff --git a/train.py b/train.py index 4f6cde4..02769f7 100644 --- a/train.py +++ b/train.py @@ -26,6 +26,9 @@ def parse_args(): parser.add_argument('--image_size', default=84, type=int) parser.add_argument('--action_repeat', default=1, type=int) parser.add_argument('--frame_stack', default=3, type=int) + parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none']) + parser.add_argument('--resource_files', type=str) + parser.add_argument('--total_frames', default=10000, type=int) # replay buffer parser.add_argument('--replay_buffer_capacity', default=1000000, type=int) # train