Compare commits
6 Commits
0fc1b8bd37
...
4d57be91cd
Author | SHA1 | Date | |
---|---|---|---|
|
4d57be91cd | ||
|
5967b6d0cc | ||
|
7785d5f985 | ||
|
e6531aab7a | ||
|
15008bbd3b | ||
|
3f5320c78f |
13
README.md
13
README.md
@ -52,16 +52,19 @@ while an evaluation entry:
|
|||||||
```
|
```
|
||||||
which just tells the expected reward `ER` evaluating current policy after `S` steps. Note that `ER` is average evaluation performance over `num_eval_episodes` episodes (usually 10).
|
which just tells the expected reward `ER` evaluating current policy after `S` steps. Note that `ER` is average evaluation performance over `num_eval_episodes` episodes (usually 10).
|
||||||
|
|
||||||
|
### Running the natural video setting
|
||||||
|
You can download the Kinetics 400 dataset and grab the driving_car label from the train dataset to replicate our setup. Some instructions for downloading the dataset can be found here: https://github.com/Showmax/kinetics-downloader.
|
||||||
|
|
||||||
## CARLA
|
## CARLA
|
||||||
Download CARLA from https://github.com/carla-simulator/carla/releases, e.g.:
|
Download CARLA from https://github.com/carla-simulator/carla/releases, e.g.:
|
||||||
1. https://carla-releases.s3.eu-west-3.amazonaws.com/Linux/CARLA_0.9.8.tar.gz
|
1. https://carla-releases.s3.eu-west-3.amazonaws.com/Linux/CARLA_0.9.6.tar.gz
|
||||||
2. https://carla-releases.s3.eu-west-3.amazonaws.com/Linux/AdditionalMaps_0.9.8.tar.gz
|
2. https://carla-releases.s3.eu-west-3.amazonaws.com/Linux/AdditionalMaps_0.9.6.tar.gz
|
||||||
|
|
||||||
Add to your python path:
|
Add to your python path:
|
||||||
```
|
```
|
||||||
export PYTHONPATH=$PYTHONPATH:/home/rmcallister/code/bisim_metric/CARLA_0.9.8/PythonAPI
|
export PYTHONPATH=$PYTHONPATH:/home/rmcallister/code/bisim_metric/CARLA_0.9.6/PythonAPI
|
||||||
export PYTHONPATH=$PYTHONPATH:/home/rmcallister/code/bisim_metric/CARLA_0.9.8/PythonAPI/carla
|
export PYTHONPATH=$PYTHONPATH:/home/rmcallister/code/bisim_metric/CARLA_0.9.6/PythonAPI/carla
|
||||||
export PYTHONPATH=$PYTHONPATH:/home/rmcallister/code/bisim_metric/CARLA_0.9.8/PythonAPI/carla/dist/carla-0.9.8-py3.5-linux-x86_64.egg
|
export PYTHONPATH=$PYTHONPATH:/home/rmcallister/code/bisim_metric/CARLA_0.9.6/PythonAPI/carla/dist/carla-0.9.8-py3.5-linux-x86_64.egg
|
||||||
```
|
```
|
||||||
and merge the directories.
|
and merge the directories.
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ dependencies:
|
|||||||
- git+git://github.com/deepmind/dm_control.git
|
- git+git://github.com/deepmind/dm_control.git
|
||||||
- git+git://github.com/1nadequacy/dmc2gym.git
|
- git+git://github.com/1nadequacy/dmc2gym.git
|
||||||
- opencv-python
|
- opencv-python
|
||||||
- pillow=6.1
|
- pillow==6.1
|
||||||
- scikit-image
|
- scikit-image
|
||||||
- scikit-video
|
- scikit-video
|
||||||
- tb-nightly
|
- tb-nightly
|
||||||
|
28
train.py
28
train.py
@ -20,7 +20,7 @@ from video import VideoRecorder
|
|||||||
from agent.baseline_agent import BaselineAgent
|
from agent.baseline_agent import BaselineAgent
|
||||||
from agent.bisim_agent import BisimAgent
|
from agent.bisim_agent import BisimAgent
|
||||||
from agent.deepmdp_agent import DeepMDPAgent
|
from agent.deepmdp_agent import DeepMDPAgent
|
||||||
from agents.navigation.carla_env import CarlaEnv
|
#from agents.navigation.carla_env import CarlaEnv
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -34,14 +34,15 @@ def parse_args():
|
|||||||
parser.add_argument('--resource_files', type=str)
|
parser.add_argument('--resource_files', type=str)
|
||||||
parser.add_argument('--eval_resource_files', type=str)
|
parser.add_argument('--eval_resource_files', type=str)
|
||||||
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
||||||
parser.add_argument('--total_frames', default=1000, type=int)
|
parser.add_argument('--total_frames', default=100, type=int)
|
||||||
|
parser.add_argument('--high_noise', action='store_true')
|
||||||
# replay buffer
|
# replay buffer
|
||||||
parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
|
parser.add_argument('--replay_buffer_capacity', default=50000, type=int)
|
||||||
# train
|
# train
|
||||||
parser.add_argument('--agent', default='bisim', type=str, choices=['baseline', 'bisim', 'deepmdp'])
|
parser.add_argument('--agent', default='bisim', type=str, choices=['baseline', 'bisim', 'deepmdp'])
|
||||||
parser.add_argument('--init_steps', default=1000, type=int)
|
parser.add_argument('--init_steps', default=1000, type=int)
|
||||||
parser.add_argument('--num_train_steps', default=1000000, type=int)
|
parser.add_argument('--num_train_steps', default=1000050, type=int)
|
||||||
parser.add_argument('--batch_size', default=512, type=int)
|
parser.add_argument('--batch_size', default=128, type=int) # 512
|
||||||
parser.add_argument('--hidden_dim', default=256, type=int)
|
parser.add_argument('--hidden_dim', default=256, type=int)
|
||||||
parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model')
|
parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model')
|
||||||
parser.add_argument('--bisim_coef', default=0.5, type=float, help='coefficient for bisim terms')
|
parser.add_argument('--bisim_coef', default=0.5, type=float, help='coefficient for bisim terms')
|
||||||
@ -50,12 +51,12 @@ def parse_args():
|
|||||||
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
|
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
|
||||||
parser.add_argument('--num_eval_episodes', default=20, type=int)
|
parser.add_argument('--num_eval_episodes', default=20, type=int)
|
||||||
# critic
|
# critic
|
||||||
parser.add_argument('--critic_lr', default=1e-3, type=float)
|
parser.add_argument('--critic_lr', default=1e-5, type=float)
|
||||||
parser.add_argument('--critic_beta', default=0.9, type=float)
|
parser.add_argument('--critic_beta', default=0.9, type=float)
|
||||||
parser.add_argument('--critic_tau', default=0.005, type=float)
|
parser.add_argument('--critic_tau', default=0.005, type=float)
|
||||||
parser.add_argument('--critic_target_update_freq', default=2, type=int)
|
parser.add_argument('--critic_target_update_freq', default=2, type=int)
|
||||||
# actor
|
# actor
|
||||||
parser.add_argument('--actor_lr', default=1e-3, type=float)
|
parser.add_argument('--actor_lr', default=1e-5, type=float)
|
||||||
parser.add_argument('--actor_beta', default=0.9, type=float)
|
parser.add_argument('--actor_beta', default=0.9, type=float)
|
||||||
parser.add_argument('--actor_log_std_min', default=-10, type=float)
|
parser.add_argument('--actor_log_std_min', default=-10, type=float)
|
||||||
parser.add_argument('--actor_log_std_max', default=2, type=float)
|
parser.add_argument('--actor_log_std_max', default=2, type=float)
|
||||||
@ -63,19 +64,19 @@ def parse_args():
|
|||||||
# encoder/decoder
|
# encoder/decoder
|
||||||
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
|
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
|
||||||
parser.add_argument('--encoder_feature_dim', default=50, type=int)
|
parser.add_argument('--encoder_feature_dim', default=50, type=int)
|
||||||
parser.add_argument('--encoder_lr', default=1e-3, type=float)
|
parser.add_argument('--encoder_lr', default=1e-5, type=float)
|
||||||
parser.add_argument('--encoder_tau', default=0.005, type=float)
|
parser.add_argument('--encoder_tau', default=0.005, type=float)
|
||||||
parser.add_argument('--encoder_stride', default=1, type=int)
|
parser.add_argument('--encoder_stride', default=1, type=int)
|
||||||
parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
|
parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
|
||||||
parser.add_argument('--decoder_lr', default=1e-3, type=float)
|
parser.add_argument('--decoder_lr', default=1e-5, type=float)
|
||||||
parser.add_argument('--decoder_update_freq', default=1, type=int)
|
parser.add_argument('--decoder_update_freq', default=1, type=int)
|
||||||
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
|
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
|
||||||
parser.add_argument('--num_layers', default=4, type=int)
|
parser.add_argument('--num_layers', default=4, type=int)
|
||||||
parser.add_argument('--num_filters', default=32, type=int)
|
parser.add_argument('--num_filters', default=32, type=int)
|
||||||
# sac
|
# sac
|
||||||
parser.add_argument('--discount', default=0.99, type=float)
|
parser.add_argument('--discount', default=0.99, type=float)
|
||||||
parser.add_argument('--init_temperature', default=0.01, type=float)
|
parser.add_argument('--init_temperature', default=0.1, type=float)
|
||||||
parser.add_argument('--alpha_lr', default=1e-3, type=float)
|
parser.add_argument('--alpha_lr', default=1e-4, type=float)
|
||||||
parser.add_argument('--alpha_beta', default=0.9, type=float)
|
parser.add_argument('--alpha_beta', default=0.9, type=float)
|
||||||
# misc
|
# misc
|
||||||
parser.add_argument('--seed', default=1, type=int)
|
parser.add_argument('--seed', default=1, type=int)
|
||||||
@ -88,6 +89,9 @@ def parse_args():
|
|||||||
parser.add_argument('--render', default=False, action='store_true')
|
parser.add_argument('--render', default=False, action='store_true')
|
||||||
parser.add_argument('--port', default=2000, type=int)
|
parser.add_argument('--port', default=2000, type=int)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
#from dmc2gym.wrappers import set_global_var
|
||||||
|
#set_global_var(args.high_noise)
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
@ -318,7 +322,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# stack several consecutive frames together
|
# stack several consecutive frames together
|
||||||
if args.encoder_type == 'pixel':
|
if args.encoder_type.startswith('pixel'):
|
||||||
env = utils.FrameStack(env, k=args.frame_stack)
|
env = utils.FrameStack(env, k=args.frame_stack)
|
||||||
eval_env = utils.FrameStack(eval_env, k=args.frame_stack)
|
eval_env = utils.FrameStack(eval_env, k=args.frame_stack)
|
||||||
|
|
||||||
|
2
video.py
2
video.py
@ -22,7 +22,7 @@ class VideoRecorder(object):
|
|||||||
self.frames = []
|
self.frames = []
|
||||||
if resource_files:
|
if resource_files:
|
||||||
files = glob.glob(os.path.expanduser(resource_files))
|
files = glob.glob(os.path.expanduser(resource_files))
|
||||||
self._bg_source = RandomVideoSource((height, width), files, grayscale=False, total_frames=1000)
|
self._bg_source = RandomVideoSource((height, width), files, grayscale=False, max_videos=50, random_bg=False)
|
||||||
else:
|
else:
|
||||||
self._bg_source = None
|
self._bg_source = None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user