diff --git a/train.py b/train.py
index 4f6cde4..73e9c1a 100644
--- a/train.py
+++ b/train.py
@@ -26,13 +26,16 @@ def parse_args():
     parser.add_argument('--image_size', default=84, type=int)
     parser.add_argument('--action_repeat', default=1, type=int)
     parser.add_argument('--frame_stack', default=3, type=int)
+    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
+    parser.add_argument('--resource_files', type=str)
+    parser.add_argument('--total_frames', default=10000, type=int)
     # replay buffer
     parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
     # train
     parser.add_argument('--agent', default='sac_ae', type=str)
     parser.add_argument('--init_steps', default=1000, type=int)
     parser.add_argument('--num_train_steps', default=1000000, type=int)
-    parser.add_argument('--batch_size', default=128, type=int)
+    parser.add_argument('--batch_size', default=512, type=int)
     parser.add_argument('--hidden_dim', default=1024, type=int)
     # eval
     parser.add_argument('--eval_freq', default=10000, type=int)
@@ -143,7 +146,10 @@ def main():
         from_pixels=(args.encoder_type == 'pixel'),
         height=args.image_size,
         width=args.image_size,
-        frame_skip=args.action_repeat
+        frame_skip=args.action_repeat,
+        img_source=args.img_source,
+        resource_files=args.resource_files,
+        total_frames=args.total_frames
     )
     env.seed(args.seed)
 
@@ -212,28 +218,65 @@ def main():
             L.log('train/episode', episode, step)
 
+        if episode_step == 0:
+            last_obs = obs
+            # sample action for data collection
+            if step < args.init_steps:
+                last_action = env.action_space.sample()
+            else:
+                with utils.eval_mode(agent):
+                    last_action = agent.sample_action(last_obs)
+
+            curr_obs, last_reward, last_done, _ = env.step(last_action)
+
+            # allow infinite bootstrap
+            last_done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(last_done)
+            episode_reward += last_reward
+
+            # sample action for data collection
+            if step < args.init_steps:
+                action = env.action_space.sample()
+            else:
+                with utils.eval_mode(agent):
+                    action = agent.sample_action(curr_obs)
+
+            next_obs, reward, done, _ = env.step(action)
+
+            # allow infinite bootstrap
+            done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
+            episode_reward += reward
+
+            replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
+
+            last_obs = curr_obs
+            last_action = action
+            last_reward = reward
+            last_done = done
+            curr_obs = next_obs
+
         # sample action for data collection
         if step < args.init_steps:
             action = env.action_space.sample()
         else:
             with utils.eval_mode(agent):
-                action = agent.sample_action(obs)
+                action = agent.sample_action(curr_obs)
 
+        # run training update
         if step >= args.init_steps:
-            num_updates = args.init_steps if step == args.init_steps else 1
+            #num_updates = args.init_steps if step == args.init_steps else 1
+            num_updates = 1
             for _ in range(num_updates):
                 agent.update(replay_buffer, L, step)
 
         next_obs, reward, done, _ = env.step(action)
 
         # allow infinit bootstrap
-        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
-            done
-        )
+        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
         episode_reward += reward
 
-        replay_buffer.add(obs, action, reward, next_obs, done_bool)
+        #replay_buffer.add(obs, action, reward, next_obs, done_bool)
+        replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
 
         obs = next_obs
         episode_step += 1
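
Note: the nine-argument replay_buffer.add(...) call introduced above implies that the replay buffer itself was extended, in a change not shown in this patch, to store two consecutive transitions per entry. The sketch below is only an illustration of what such a buffer could look like, assuming a NumPy-backed layout similar to the repository's existing ReplayBuffer; the class name, attribute names, and sample() method are hypothetical, not the actual implementation.

# Hypothetical sketch (not part of this patch): a buffer storing
# (s_t, a_t, r_t, s_{t+1}, d_t, a_{t+1}, r_{t+1}, s_{t+2}, d_{t+1}) tuples,
# matching the nine-argument add() call used in the training loop above.
import numpy as np

class TwoStepReplayBuffer:
    def __init__(self, obs_shape, action_shape, capacity):
        self.capacity = capacity
        self.idx = 0
        self.full = False
        # one array per element of the two-step transition
        self.obses = np.empty((capacity, *obs_shape), dtype=np.uint8)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.curr_obses = np.empty((capacity, *obs_shape), dtype=np.uint8)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)
        self.next_actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.next_rewards = np.empty((capacity, 1), dtype=np.float32)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=np.uint8)
        self.next_not_dones = np.empty((capacity, 1), dtype=np.float32)

    def add(self, obs, action, reward, curr_obs, done_bool,
            next_action, next_reward, next_obs, next_done_bool):
        # argument order mirrors the call site in train.py; done flags are
        # stored as not_done masks (1 - done), a common SAC-style convention
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        self.rewards[self.idx] = reward
        np.copyto(self.curr_obses[self.idx], curr_obs)
        self.not_dones[self.idx] = 1.0 - done_bool
        np.copyto(self.next_actions[self.idx], next_action)
        self.next_rewards[self.idx] = next_reward
        np.copyto(self.next_obses[self.idx], next_obs)
        self.next_not_dones[self.idx] = 1.0 - next_done_bool
        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def sample(self, batch_size):
        # uniform sampling over the filled portion of the buffer
        max_idx = self.capacity if self.full else self.idx
        idxs = np.random.randint(0, max_idx, size=batch_size)
        return (self.obses[idxs], self.actions[idxs], self.rewards[idxs],
                self.curr_obses[idxs], self.not_dones[idxs],
                self.next_actions[idxs], self.next_rewards[idxs],
                self.next_obses[idxs], self.next_not_dones[idxs])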