Importing Training module
This commit is contained in:
parent
7fa560e21c
commit
76442c02f5
59
train.py
59
train.py
@ -26,13 +26,16 @@ def parse_args():
|
||||
parser.add_argument('--image_size', default=84, type=int)
|
||||
parser.add_argument('--action_repeat', default=1, type=int)
|
||||
parser.add_argument('--frame_stack', default=3, type=int)
|
||||
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
||||
parser.add_argument('--resource_files', type=str)
|
||||
parser.add_argument('--total_frames', default=10000, type=int)
|
||||
# replay buffer
|
||||
parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
|
||||
# train
|
||||
parser.add_argument('--agent', default='sac_ae', type=str)
|
||||
parser.add_argument('--init_steps', default=1000, type=int)
|
||||
parser.add_argument('--num_train_steps', default=1000000, type=int)
|
||||
parser.add_argument('--batch_size', default=128, type=int)
|
||||
parser.add_argument('--batch_size', default=512, type=int)
|
||||
parser.add_argument('--hidden_dim', default=1024, type=int)
|
||||
# eval
|
||||
parser.add_argument('--eval_freq', default=10000, type=int)
|
||||
@ -143,7 +146,10 @@ def main():
|
||||
from_pixels=(args.encoder_type == 'pixel'),
|
||||
height=args.image_size,
|
||||
width=args.image_size,
|
||||
frame_skip=args.action_repeat
|
||||
frame_skip=args.action_repeat,
|
||||
img_source=args.img_source,
|
||||
resource_files=args.resource_files,
|
||||
total_frames=args.total_frames
|
||||
)
|
||||
env.seed(args.seed)
|
||||
|
||||
@ -212,28 +218,65 @@ def main():
|
||||
|
||||
L.log('train/episode', episode, step)
|
||||
|
||||
if episode_step == 0:
|
||||
last_obs = obs
|
||||
# sample action for data collection
|
||||
if step < args.init_steps:
|
||||
last_action = env.action_space.sample()
|
||||
else:
|
||||
with utils.eval_mode(agent):
|
||||
last_action = agent.sample_action(last_obs)
|
||||
|
||||
curr_obs, last_reward, last_done, _ = env.step(last_action)
|
||||
|
||||
# allow infinit bootstrap
|
||||
last_done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(last_done)
|
||||
episode_reward += last_reward
|
||||
|
||||
# sample action for data collection
|
||||
if step < args.init_steps:
|
||||
action = env.action_space.sample()
|
||||
else:
|
||||
with utils.eval_mode(agent):
|
||||
action = agent.sample_action(curr_obs)
|
||||
|
||||
next_obs, reward, done, _ = env.step(action)
|
||||
|
||||
# allow infinit bootstrap
|
||||
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
|
||||
episode_reward += reward
|
||||
|
||||
replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
|
||||
|
||||
last_obs = curr_obs
|
||||
last_action = action
|
||||
last_reward = reward
|
||||
last_done = done
|
||||
curr_obs = next_obs
|
||||
|
||||
# sample action for data collection
|
||||
if step < args.init_steps:
|
||||
action = env.action_space.sample()
|
||||
else:
|
||||
with utils.eval_mode(agent):
|
||||
action = agent.sample_action(obs)
|
||||
action = agent.sample_action(curr_obs)
|
||||
|
||||
|
||||
# run training update
|
||||
if step >= args.init_steps:
|
||||
num_updates = args.init_steps if step == args.init_steps else 1
|
||||
#num_updates = args.init_steps if step == args.init_steps else 1
|
||||
num_updates = 1 if step == args.init_steps else 1
|
||||
for _ in range(num_updates):
|
||||
agent.update(replay_buffer, L, step)
|
||||
|
||||
next_obs, reward, done, _ = env.step(action)
|
||||
|
||||
# allow infinit bootstrap
|
||||
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
|
||||
done
|
||||
)
|
||||
done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
|
||||
episode_reward += reward
|
||||
|
||||
replay_buffer.add(obs, action, reward, next_obs, done_bool)
|
||||
#replay_buffer.add(obs, action, reward, next_obs, done_bool)
|
||||
replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
|
||||
|
||||
obs = next_obs
|
||||
episode_step += 1
|
||||
|
Loading…
Reference in New Issue
Block a user