Importing Training module

Vedant Dave 2023-05-16 12:28:27 +02:00
parent 7fa560e21c
commit 76442c02f5


@@ -26,13 +26,16 @@ def parse_args():
     parser.add_argument('--image_size', default=84, type=int)
     parser.add_argument('--action_repeat', default=1, type=int)
     parser.add_argument('--frame_stack', default=3, type=int)
+    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
+    parser.add_argument('--resource_files', type=str)
+    parser.add_argument('--total_frames', default=10000, type=int)
     # replay buffer
     parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
     # train
     parser.add_argument('--agent', default='sac_ae', type=str)
     parser.add_argument('--init_steps', default=1000, type=int)
     parser.add_argument('--num_train_steps', default=1000000, type=int)
-    parser.add_argument('--batch_size', default=128, type=int)
+    parser.add_argument('--batch_size', default=512, type=int)
     parser.add_argument('--hidden_dim', default=1024, type=int)
     # eval
     parser.add_argument('--eval_freq', default=10000, type=int)
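
The three new flags are parsed alongside the existing ones; a minimal sketch of how they might be supplied (the entry-point name train.py and the video glob are assumptions for illustration, not taken from this commit):

    # Minimal sketch: exercise the new flags through parse_args().
    # The resource path below is an assumed example.
    import sys
    sys.argv = [
        'train.py',
        '--img_source', 'video',
        '--resource_files', '/data/distractors/*.mp4',
        '--total_frames', '10000',
    ]
    args = parse_args()
    assert args.img_source == 'video' and args.total_frames == 10000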
@@ -143,7 +146,10 @@ def main():
         from_pixels=(args.encoder_type == 'pixel'),
         height=args.image_size,
         width=args.image_size,
-        frame_skip=args.action_repeat
+        frame_skip=args.action_repeat,
+        img_source=args.img_source,
+        resource_files=args.resource_files,
+        total_frames=args.total_frames
     )
     env.seed(args.seed)
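
The call site now forwards the three new arguments to the environment constructor, which lies outside this diff. A hypothetical signature consistent with the call (names and semantics are assumptions, not confirmed by the commit):

    # Hypothetical sketch only; the real constructor is not part of this diff.
    def make_env(from_pixels, height, width, frame_skip,
                 img_source=None, resource_files=None, total_frames=None,
                 **kwargs):
        # Assumed semantics: img_source selects the background source
        # ('color', 'noise', 'images', 'video' or 'none'), resource_files
        # locates the image/video files to sample from, and total_frames
        # caps how many frames are loaded from them.
        raise NotImplementedError  # placeholder body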
@@ -212,28 +218,65 @@ def main():
             L.log('train/episode', episode, step)
 
+        if episode_step == 0:
+            last_obs = obs
+            # sample action for data collection
+            if step < args.init_steps:
+                last_action = env.action_space.sample()
+            else:
+                with utils.eval_mode(agent):
+                    last_action = agent.sample_action(last_obs)
+            curr_obs, last_reward, last_done, _ = env.step(last_action)
+            # allow infinite bootstrap
+            last_done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(last_done)
+            episode_reward += last_reward
+
         # sample action for data collection
         if step < args.init_steps:
             action = env.action_space.sample()
         else:
             with utils.eval_mode(agent):
-                action = agent.sample_action(obs)
+                action = agent.sample_action(curr_obs)
+
+        next_obs, reward, done, _ = env.step(action)
+        # allow infinite bootstrap
+        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
+        episode_reward += reward
+        replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
+        last_obs = curr_obs
+        last_action = action
+        last_reward = reward
+        last_done = done
+        curr_obs = next_obs
+
+        # sample action for data collection
+        if step < args.init_steps:
+            action = env.action_space.sample()
+        else:
+            with utils.eval_mode(agent):
+                action = agent.sample_action(curr_obs)
 
         # run training update
         if step >= args.init_steps:
-            num_updates = args.init_steps if step == args.init_steps else 1
+            #num_updates = args.init_steps if step == args.init_steps else 1
+            num_updates = 1 if step == args.init_steps else 1
             for _ in range(num_updates):
                 agent.update(replay_buffer, L, step)
 
         next_obs, reward, done, _ = env.step(action)
 
         # allow infinite bootstrap
-        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
-            done
-        )
+        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
         episode_reward += reward
-        replay_buffer.add(obs, action, reward, next_obs, done_bool)
+        #replay_buffer.add(obs, action, reward, next_obs, done_bool)
+        replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
 
         obs = next_obs
         episode_step += 1
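
With this change each buffer entry holds two consecutive transitions: the previous step (last_obs, last_action, last_reward, curr_obs, last_done_bool) together with the current one (action, reward, next_obs, done_bool). The buffer itself is not modified in this diff; a minimal sketch of an add() compatible with the new call site (argument names mirror the call, the storage scheme is an assumption):

    # Minimal sketch, assuming the replay buffer was widened to store
    # transition pairs; the real implementation is not part of this commit.
    class PairedReplayBuffer:
        def __init__(self, capacity):
            self.capacity = capacity
            self.storage = []

        def add(self, last_obs, last_action, last_reward, curr_obs, last_done,
                action, reward, next_obs, done):
            # Keep (s_{t-1}, a_{t-1}, r_{t-1}, s_t, d_{t-1}) and
            # (a_t, r_t, s_{t+1}, d_t) in one entry so that consecutive
            # steps can be sampled together.
            if len(self.storage) >= self.capacity:
                self.storage.pop(0)  # drop the oldest entry when full
            self.storage.append((last_obs, last_action, last_reward, curr_obs,
                                 last_done, action, reward, next_obs, done))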