Importing Training module
This commit is contained in:
parent
7fa560e21c
commit
76442c02f5
59
train.py
59
train.py
@ -26,13 +26,16 @@ def parse_args():
|
|||||||
parser.add_argument('--image_size', default=84, type=int)
|
parser.add_argument('--image_size', default=84, type=int)
|
||||||
parser.add_argument('--action_repeat', default=1, type=int)
|
parser.add_argument('--action_repeat', default=1, type=int)
|
||||||
parser.add_argument('--frame_stack', default=3, type=int)
|
parser.add_argument('--frame_stack', default=3, type=int)
|
||||||
|
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
||||||
|
parser.add_argument('--resource_files', type=str)
|
||||||
|
parser.add_argument('--total_frames', default=10000, type=int)
|
||||||
# replay buffer
|
# replay buffer
|
||||||
parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
|
parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
|
||||||
# train
|
# train
|
||||||
parser.add_argument('--agent', default='sac_ae', type=str)
|
parser.add_argument('--agent', default='sac_ae', type=str)
|
||||||
parser.add_argument('--init_steps', default=1000, type=int)
|
parser.add_argument('--init_steps', default=1000, type=int)
|
||||||
parser.add_argument('--num_train_steps', default=1000000, type=int)
|
parser.add_argument('--num_train_steps', default=1000000, type=int)
|
||||||
parser.add_argument('--batch_size', default=128, type=int)
|
parser.add_argument('--batch_size', default=512, type=int)
|
||||||
parser.add_argument('--hidden_dim', default=1024, type=int)
|
parser.add_argument('--hidden_dim', default=1024, type=int)
|
||||||
# eval
|
# eval
|
||||||
parser.add_argument('--eval_freq', default=10000, type=int)
|
parser.add_argument('--eval_freq', default=10000, type=int)
|
||||||
@ -143,7 +146,10 @@ def main():
|
|||||||
from_pixels=(args.encoder_type == 'pixel'),
|
from_pixels=(args.encoder_type == 'pixel'),
|
||||||
height=args.image_size,
|
height=args.image_size,
|
||||||
width=args.image_size,
|
width=args.image_size,
|
||||||
frame_skip=args.action_repeat
|
frame_skip=args.action_repeat,
|
||||||
|
img_source=args.img_source,
|
||||||
|
resource_files=args.resource_files,
|
||||||
|
total_frames=args.total_frames
|
||||||
)
|
)
|
||||||
env.seed(args.seed)
|
env.seed(args.seed)
|
||||||
|
|
||||||
@ -212,28 +218,65 @@ def main():
|
|||||||
|
|
||||||
L.log('train/episode', episode, step)
|
L.log('train/episode', episode, step)
|
||||||
|
|
||||||
|
if episode_step == 0:
|
||||||
|
last_obs = obs
|
||||||
|
# sample action for data collection
|
||||||
|
if step < args.init_steps:
|
||||||
|
last_action = env.action_space.sample()
|
||||||
|
else:
|
||||||
|
with utils.eval_mode(agent):
|
||||||
|
last_action = agent.sample_action(last_obs)
|
||||||
|
|
||||||
|
curr_obs, last_reward, last_done, _ = env.step(last_action)
|
||||||
|
|
||||||
|
# allow infinit bootstrap
|
||||||
|
last_done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(last_done)
|
||||||
|
episode_reward += last_reward
|
||||||
|
|
||||||
|
# sample action for data collection
|
||||||
|
if step < args.init_steps:
|
||||||
|
action = env.action_space.sample()
|
||||||
|
else:
|
||||||
|
with utils.eval_mode(agent):
|
||||||
|
action = agent.sample_action(curr_obs)
|
||||||
|
|
||||||
|
next_obs, reward, done, _ = env.step(action)
|
||||||
|
|
||||||
|
# allow infinit bootstrap
|
||||||
|
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
|
||||||
|
episode_reward += reward
|
||||||
|
|
||||||
|
replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
|
||||||
|
|
||||||
|
last_obs = curr_obs
|
||||||
|
last_action = action
|
||||||
|
last_reward = reward
|
||||||
|
last_done = done
|
||||||
|
curr_obs = next_obs
|
||||||
|
|
||||||
# sample action for data collection
|
# sample action for data collection
|
||||||
if step < args.init_steps:
|
if step < args.init_steps:
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
else:
|
else:
|
||||||
with utils.eval_mode(agent):
|
with utils.eval_mode(agent):
|
||||||
action = agent.sample_action(obs)
|
action = agent.sample_action(curr_obs)
|
||||||
|
|
||||||
|
|
||||||
# run training update
|
# run training update
|
||||||
if step >= args.init_steps:
|
if step >= args.init_steps:
|
||||||
num_updates = args.init_steps if step == args.init_steps else 1
|
#num_updates = args.init_steps if step == args.init_steps else 1
|
||||||
|
num_updates = 1 if step == args.init_steps else 1
|
||||||
for _ in range(num_updates):
|
for _ in range(num_updates):
|
||||||
agent.update(replay_buffer, L, step)
|
agent.update(replay_buffer, L, step)
|
||||||
|
|
||||||
next_obs, reward, done, _ = env.step(action)
|
next_obs, reward, done, _ = env.step(action)
|
||||||
|
|
||||||
# allow infinit bootstrap
|
# allow infinit bootstrap
|
||||||
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
|
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
|
||||||
done
|
|
||||||
)
|
|
||||||
episode_reward += reward
|
episode_reward += reward
|
||||||
|
|
||||||
replay_buffer.add(obs, action, reward, next_obs, done_bool)
|
#replay_buffer.add(obs, action, reward, next_obs, done_bool)
|
||||||
|
replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
|
||||||
|
|
||||||
obs = next_obs
|
obs = next_obs
|
||||||
episode_step += 1
|
episode_step += 1
|
||||||
|
Loading…
Reference in New Issue
Block a user