Compare commits
5 Commits
82cee4746f
...
76442c02f5
Author | SHA1 | Date | |
---|---|---|---|
76442c02f5 | |||
|
7fa560e21c | ||
|
82ebe8cb05 | ||
|
27643e916d | ||
|
a86c6dfa82 |
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2020 Denis Yarats
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
12
README.md
12
README.md
@ -8,6 +8,18 @@ This is PyTorch implementation of SAC+AE from
|
||||
|
||||
[[Paper]](https://arxiv.org/abs/1910.01741) [[Webpage]](https://sites.google.com/view/sac-ae/home)
|
||||
|
||||
## Citation
|
||||
If you use this repo in your research, please consider citing the paper as follows
|
||||
```
|
||||
@article{yarats2019improving,
|
||||
title={Improving Sample Efficiency in Model-Free Reinforcement Learning from Images},
|
||||
author={Denis Yarats and Amy Zhang and Ilya Kostrikov and Brandon Amos and Joelle Pineau and Rob Fergus},
|
||||
year={2019},
|
||||
eprint={1910.01741},
|
||||
archivePrefix={arXiv}
|
||||
}
|
||||
```
|
||||
|
||||
## Requirements
|
||||
We assume you have access to a gpu that can run CUDA 9.2. Then, the simplest way to install all required dependencies is to create an anaconda environment by running:
|
||||
```
|
||||
|
@ -11,7 +11,7 @@ dependencies:
|
||||
- pip:
|
||||
- termcolor
|
||||
- git+git://github.com/deepmind/dm_control.git
|
||||
- git+git://github.com/1nadequacy/dmc2gym.git
|
||||
- git+git://github.com/denisyarats/dmc2gym.git
|
||||
- tb-nightly
|
||||
- imageio
|
||||
- imageio-ffmpeg
|
59
train.py
59
train.py
@ -26,13 +26,16 @@ def parse_args():
|
||||
parser.add_argument('--image_size', default=84, type=int)
|
||||
parser.add_argument('--action_repeat', default=1, type=int)
|
||||
parser.add_argument('--frame_stack', default=3, type=int)
|
||||
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
||||
parser.add_argument('--resource_files', type=str)
|
||||
parser.add_argument('--total_frames', default=10000, type=int)
|
||||
# replay buffer
|
||||
parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
|
||||
# train
|
||||
parser.add_argument('--agent', default='sac_ae', type=str)
|
||||
parser.add_argument('--init_steps', default=1000, type=int)
|
||||
parser.add_argument('--num_train_steps', default=1000000, type=int)
|
||||
parser.add_argument('--batch_size', default=128, type=int)
|
||||
parser.add_argument('--batch_size', default=512, type=int)
|
||||
parser.add_argument('--hidden_dim', default=1024, type=int)
|
||||
# eval
|
||||
parser.add_argument('--eval_freq', default=10000, type=int)
|
||||
@ -143,7 +146,10 @@ def main():
|
||||
from_pixels=(args.encoder_type == 'pixel'),
|
||||
height=args.image_size,
|
||||
width=args.image_size,
|
||||
frame_skip=args.action_repeat
|
||||
frame_skip=args.action_repeat,
|
||||
img_source=args.img_source,
|
||||
resource_files=args.resource_files,
|
||||
total_frames=args.total_frames
|
||||
)
|
||||
env.seed(args.seed)
|
||||
|
||||
@ -212,28 +218,65 @@ def main():
|
||||
|
||||
L.log('train/episode', episode, step)
|
||||
|
||||
if episode_step == 0:
|
||||
last_obs = obs
|
||||
# sample action for data collection
|
||||
if step < args.init_steps:
|
||||
last_action = env.action_space.sample()
|
||||
else:
|
||||
with utils.eval_mode(agent):
|
||||
last_action = agent.sample_action(last_obs)
|
||||
|
||||
curr_obs, last_reward, last_done, _ = env.step(last_action)
|
||||
|
||||
# allow infinit bootstrap
|
||||
last_done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(last_done)
|
||||
episode_reward += last_reward
|
||||
|
||||
# sample action for data collection
|
||||
if step < args.init_steps:
|
||||
action = env.action_space.sample()
|
||||
else:
|
||||
with utils.eval_mode(agent):
|
||||
action = agent.sample_action(curr_obs)
|
||||
|
||||
next_obs, reward, done, _ = env.step(action)
|
||||
|
||||
# allow infinit bootstrap
|
||||
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
|
||||
episode_reward += reward
|
||||
|
||||
replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
|
||||
|
||||
last_obs = curr_obs
|
||||
last_action = action
|
||||
last_reward = reward
|
||||
last_done = done
|
||||
curr_obs = next_obs
|
||||
|
||||
# sample action for data collection
|
||||
if step < args.init_steps:
|
||||
action = env.action_space.sample()
|
||||
else:
|
||||
with utils.eval_mode(agent):
|
||||
action = agent.sample_action(obs)
|
||||
action = agent.sample_action(curr_obs)
|
||||
|
||||
|
||||
# run training update
|
||||
if step >= args.init_steps:
|
||||
num_updates = args.init_steps if step == args.init_steps else 1
|
||||
#num_updates = args.init_steps if step == args.init_steps else 1
|
||||
num_updates = 1 if step == args.init_steps else 1
|
||||
for _ in range(num_updates):
|
||||
agent.update(replay_buffer, L, step)
|
||||
|
||||
next_obs, reward, done, _ = env.step(action)
|
||||
|
||||
# allow infinit bootstrap
|
||||
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
|
||||
done
|
||||
)
|
||||
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
|
||||
episode_reward += reward
|
||||
|
||||
replay_buffer.add(obs, action, reward, next_obs, done_bool)
|
||||
#replay_buffer.add(obs, action, reward, next_obs, done_bool)
|
||||
replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
|
||||
|
||||
obs = next_obs
|
||||
episode_step += 1
|
||||
|
Loading…
Reference in New Issue
Block a user