Compare commits
No commits in common. "76442c02f5aed690065ee33f83db02513ab597aa" and "82cee4746f67699ef8a47f3fa2843e7f67bf42cb" have entirely different histories.
76442c02f5 ... 82cee4746f
LICENSE (21 changed lines)
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2020 Denis Yarats
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
README.md (12 changed lines)
@@ -8,18 +8,6 @@ This is PyTorch implementation of SAC+AE from
 
 [[Paper]](https://arxiv.org/abs/1910.01741) [[Webpage]](https://sites.google.com/view/sac-ae/home)
 
-## Citation
-If you use this repo in your research, please consider citing the paper as follows
-```
-@article{yarats2019improving,
-  title={Improving Sample Efficiency in Model-Free Reinforcement Learning from Images},
-  author={Denis Yarats and Amy Zhang and Ilya Kostrikov and Brandon Amos and Joelle Pineau and Rob Fergus},
-  year={2019},
-  eprint={1910.01741},
-  archivePrefix={arXiv}
-}
-```
-
 ## Requirements
 We assume you have access to a gpu that can run CUDA 9.2. Then, the simplest way to install all required dependencies is to create an anaconda environment by running:
 ```
@@ -11,7 +11,7 @@ dependencies:
   - pip:
     - termcolor
     - git+git://github.com/deepmind/dm_control.git
-    - git+git://github.com/denisyarats/dmc2gym.git
+    - git+git://github.com/1nadequacy/dmc2gym.git
     - tb-nightly
     - imageio
     - imageio-ffmpeg
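A quick way to confirm which dmc2gym ends up in the environment is to import it after the install. This check is mine, not part of the repository, and only assumes the fork installs under the package name `dmc2gym`.

```python
# Sanity check (not part of the repo): the dmc2gym fork pinned above should be
# importable and expose its gym-style make() entry point.
import dmc2gym

assert callable(dmc2gym.make)
print('dmc2gym import OK')
```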
train.py (59 changed lines)
@@ -26,16 +26,13 @@ def parse_args():
     parser.add_argument('--image_size', default=84, type=int)
     parser.add_argument('--action_repeat', default=1, type=int)
     parser.add_argument('--frame_stack', default=3, type=int)
-    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
-    parser.add_argument('--resource_files', type=str)
-    parser.add_argument('--total_frames', default=10000, type=int)
     # replay buffer
     parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
     # train
     parser.add_argument('--agent', default='sac_ae', type=str)
     parser.add_argument('--init_steps', default=1000, type=int)
     parser.add_argument('--num_train_steps', default=1000000, type=int)
-    parser.add_argument('--batch_size', default=512, type=int)
+    parser.add_argument('--batch_size', default=128, type=int)
     parser.add_argument('--hidden_dim', default=1024, type=int)
     # eval
     parser.add_argument('--eval_freq', default=10000, type=int)
@@ -146,10 +143,7 @@ def main():
         from_pixels=(args.encoder_type == 'pixel'),
         height=args.image_size,
         width=args.image_size,
-        frame_skip=args.action_repeat,
-        img_source=args.img_source,
-        resource_files=args.resource_files,
-        total_frames=args.total_frames
+        frame_skip=args.action_repeat
     )
     env.seed(args.seed)
 
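On its own, the right-hand side of this hunk leaves a plain dm_control environment with no distractor-background arguments. Below is a minimal sketch of the resulting call; the keyword set mirrors the hunk, while `domain_name`, `task_name`, the seed, and `visualize_reward=False` are assumptions filled in from the surrounding call rather than shown in this diff.

```python
# Minimal sketch of the environment construction after this change (illustrative values).
import dmc2gym

env = dmc2gym.make(
    domain_name='cheetah',    # assumed; set from args above this hunk
    task_name='run',          # assumed
    seed=1,
    visualize_reward=False,   # dmc2gym requires this when learning from pixels
    from_pixels=True,         # args.encoder_type == 'pixel'
    height=84,                # args.image_size
    width=84,                 # args.image_size
    frame_skip=1              # args.action_repeat
)
env.seed(1)
obs = env.reset()
print(obs.shape, env.action_space.shape)
```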
@@ -218,65 +212,28 @@ def main():
 
             L.log('train/episode', episode, step)
 
-        if episode_step == 0:
-            last_obs = obs
-            # sample action for data collection
-            if step < args.init_steps:
-                last_action = env.action_space.sample()
-            else:
-                with utils.eval_mode(agent):
-                    last_action = agent.sample_action(last_obs)
-
-            curr_obs, last_reward, last_done, _ = env.step(last_action)
-
-            # allow infinit bootstrap
-            last_done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(last_done)
-            episode_reward += last_reward
-
         # sample action for data collection
         if step < args.init_steps:
             action = env.action_space.sample()
         else:
             with utils.eval_mode(agent):
-                action = agent.sample_action(curr_obs)
+                action = agent.sample_action(obs)
 
-        next_obs, reward, done, _ = env.step(action)
-
-        # allow infinit bootstrap
-        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
-        episode_reward += reward
-
-        replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
-
-        last_obs = curr_obs
-        last_action = action
-        last_reward = reward
-        last_done = done
-        curr_obs = next_obs
-
-        # sample action for data collection
-        if step < args.init_steps:
-            action = env.action_space.sample()
-        else:
-            with utils.eval_mode(agent):
-                action = agent.sample_action(curr_obs)
-
-
         # run training update
         if step >= args.init_steps:
-            #num_updates = args.init_steps if step == args.init_steps else 1
-            num_updates = 1 if step == args.init_steps else 1
+            num_updates = args.init_steps if step == args.init_steps else 1
             for _ in range(num_updates):
                 agent.update(replay_buffer, L, step)
 
         next_obs, reward, done, _ = env.step(action)
 
         # allow infinit bootstrap
-        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
+        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
+            done
+        )
         episode_reward += reward
 
-        #replay_buffer.add(obs, action, reward, next_obs, done_bool)
-        replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
+        replay_buffer.add(obs, action, reward, next_obs, done_bool)
 
         obs = next_obs
         episode_step += 1
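Net effect of this hunk: the two-transition bookkeeping on the left (`last_obs`, `last_action`, `curr_obs`, a second sampling block, and the nine-argument `replay_buffer.add`) is dropped, and the loop goes back to storing one transition per step. A condensed sketch of the restored step follows; the helper-function wrapper and its return values are mine for readability, while in the file this code runs inline inside `main()`.

```python
# Condensed view of the restored single-transition step (right-hand side above),
# folded into a helper. utils.eval_mode, the agent, env, replay buffer and
# logger L all come from the repository's own modules.
import utils

def interaction_step(env, agent, replay_buffer, L, args,
                     obs, step, episode_step, episode_reward):
    # sample action for data collection
    if step < args.init_steps:
        action = env.action_space.sample()
    else:
        with utils.eval_mode(agent):
            action = agent.sample_action(obs)

    # run training update
    if step >= args.init_steps:
        num_updates = args.init_steps if step == args.init_steps else 1
        for _ in range(num_updates):
            agent.update(replay_buffer, L, step)

    next_obs, reward, done, _ = env.step(action)

    # allow infinite bootstrap: a time-limit cutoff is not stored as a true done
    done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
    episode_reward += reward

    replay_buffer.add(obs, action, reward, next_obs, done_bool)

    return next_obs, done, episode_reward, episode_step + 1
```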