diff --git a/logger.py b/logger.py
index 3a2adda..8e31fd4 100644
--- a/logger.py
+++ b/logger.py
@@ -7,7 +7,6 @@ import torch
 import torchvision
 import numpy as np
 from termcolor import colored
-from datetime import datetime
 
 FORMAT_CONFIG = {
     'rl': {
@@ -94,10 +93,8 @@ class MetersGroup(object):
 class Logger(object):
     def __init__(self, log_dir, use_tb=True, config='rl'):
         self._log_dir = log_dir
-        now = datetime.now()
-        dt_string = now.strftime("%d_%m_%Y-%H_%M_%S")
         if use_tb:
-            tb_dir = os.path.join(log_dir, 'runs/tb_'+dt_string)
+            tb_dir = os.path.join(log_dir, 'tb')
             if os.path.exists(tb_dir):
                 shutil.rmtree(tb_dir)
             self._sw = SummaryWriter(tb_dir)
diff --git a/sac_ae.py b/sac_ae.py
index 3499fac..0e5d915 100644
--- a/sac_ae.py
+++ b/sac_ae.py
@@ -6,7 +6,7 @@ import copy
 import math
 
 import utils
-from encoder import make_encoder, club_loss, TransitionModel
+from encoder import make_encoder
 from decoder import make_decoder
 
 LOG_FREQ = 10000
@@ -70,8 +70,10 @@ class Actor(nn.Module):
         self.outputs = dict()
         self.apply(weight_init)
 
-    def forward(self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False):
-        obs, _, _ = self.encoder(obs, detach=detach_encoder)
+    def forward(
+        self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False
+    ):
+        obs = self.encoder(obs, detach=detach_encoder)
 
         mu, log_std = self.trunk(obs).chunk(2, dim=-1)
@@ -98,6 +100,7 @@ class Actor(nn.Module):
             log_pi = None
 
         mu, pi, log_pi = squash(mu, pi, log_pi)
+
         return mu, pi, log_pi, log_std
 
     def log(self, L, step, log_freq=LOG_FREQ):
@@ -156,7 +159,7 @@ class Critic(nn.Module):
 
     def forward(self, obs, action, detach_encoder=False):
         # detach_encoder allows to stop gradient propogation to encoder
-        obs, _ , _ = self.encoder(obs, detach=detach_encoder)
+        obs = self.encoder(obs, detach=detach_encoder)
 
         q1 = self.Q1(obs, action)
         q2 = self.Q2(obs, action)
@@ -179,53 +182,7 @@ class Critic(nn.Module):
             L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step)
             L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step)
 
-class CURL(nn.Module):
-    """
-    CURL
-    """
-    def __init__(self, obs_shape, z_dim, a_dim, batch_size, critic, critic_target, output_type="continuous"):
-        super(CURL, self).__init__()
-        self.batch_size = batch_size
-
-        self.encoder = critic.encoder
-
-        self.encoder_target = critic_target.encoder
-
-        self.W = nn.Parameter(torch.rand(z_dim, z_dim))
-        self.combine = nn.Linear(z_dim + a_dim, z_dim)
-        self.output_type = output_type
-
-    def encode(self, x, a=None, detach=False, ema=False):
-        """
-        Encoder: z_t = e(x_t)
-        :param x: x_t, x y coordinates
-        :return: z_t, value in r2
-        """
-        if ema:
-            with torch.no_grad():
-                z_out = self.encoder_target(x)[0]
-                z_out = self.combine(torch.concat((z_out,a), dim=-1))
-        else:
-            z_out = self.encoder(x)[0]
-
-        if detach:
-            z_out = z_out.detach()
-        return z_out
-
-    def compute_logits(self, z_a, z_pos):
-        """
-        Uses logits trick for CURL:
-        - compute (B,B) matrix z_a (W z_pos.T)
-        - positives are all diagonal elements
-        - negatives are all other elements
-        - to compute loss use multiclass cross entropy with identity matrix for labels
-        """
-        Wz = torch.matmul(self.W, z_pos.T)  # (z_dim,B)
-        logits = torch.matmul(z_a, Wz)  # (B,B)
-        logits = logits - torch.max(logits, 1)[0][:, None]
-        return logits
-
 class SacAeAgent(object):
     """SAC+AE algorithm."""
     def __init__(
         self,
@@ -267,12 +224,6 @@ class SacAeAgent(object):
         self.critic_target_update_freq = critic_target_update_freq
         self.decoder_update_freq = decoder_update_freq
         self.decoder_latent_lambda = decoder_latent_lambda
-
-        self.transition_model = TransitionModel(
-            encoder_feature_dim,
-            hidden_dim,
-            action_shape[0],
-            encoder_feature_dim).to(device)
 
         self.actor = Actor(
             obs_shape, action_shape, hidden_dim, encoder_type,
@@ -300,11 +251,6 @@ class SacAeAgent(object):
         # set target entropy to -|A|
         self.target_entropy = -np.prod(action_shape)
 
-        self.CURL = CURL(obs_shape, encoder_feature_dim, action_shape[0],
-                         obs_shape[0], self.critic,self.critic_target, output_type='continuous').to(self.device)
-
-        self.cross_entropy_loss = nn.CrossEntropyLoss()
-
         self.decoder = None
         if decoder_type != 'identity':
             # create decoder
@@ -335,10 +281,6 @@ class SacAeAgent(object):
             self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
         )
 
-        self.cpc_optimizer = torch.optim.Adam(
-            self.CURL.parameters(), lr=encoder_lr
-        )
-
         self.log_alpha_optimizer = torch.optim.Adam(
             [self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999)
         )
@@ -387,6 +329,7 @@ class SacAeAgent(object):
                                  target_Q) + F.mse_loss(current_Q2, target_Q)
         L.log('train_critic/loss', critic_loss, step)
+
         # Optimize the critic
         self.critic_optimizer.zero_grad()
         critic_loss.backward()
 
@@ -423,38 +366,12 @@ class SacAeAgent(object):
         alpha_loss.backward()
         self.log_alpha_optimizer.step()
 
-    def update_decoder(self, last_obs, last_action, last_reward, curr_obs, last_not_done, action, reward, next_obs, not_done, target_obs, L, step):
-        h_curr, mu_h_curr, std_h_curr = self.critic.encoder(curr_obs)
+    def update_decoder(self, obs, target_obs, L, step):
+        h = self.critic.encoder(obs)
 
-        with torch.no_grad():
-            h_last, _, _ = self.critic.encoder(last_obs)
-            self.transition_model.init_states(last_obs.shape[0], self.device)
-            curr_state = self.transition_model.transition_step(h_last, last_action, self.transition_model.prev_history, last_not_done)
-
-        hist = curr_state["history"]
-        next_state = self.transition_model.transition_step(h_curr, action, hist, not_done)
-
-        next_state_mu = next_state["mean"]
-        next_state_sigma = next_state["std"]
-        next_state_sample = next_state["sample"]
-        pred_dist = torch.distributions.Normal(next_state_mu, next_state_sigma)
-
-        h, mu_h_next, logstd_h_next = self.critic.encoder(next_obs)
-        std_h_next = torch.exp(logstd_h_next)
-        enc_dist = torch.distributions.Normal(mu_h_next, std_h_next)
-        enc_loss = torch.mean(torch.distributions.kl.kl_divergence(enc_dist,pred_dist)) * 0.1
-
-        z_pos = self.CURL.encode(next_obs, action.detach(), ema=True)
-        logits = self.CURL.compute_logits(h_curr, z_pos)
-        labels = torch.arange(logits.shape[0]).long().to(self.device)
-        lb_loss = self.cross_entropy_loss(logits, labels) * 0.1
-
-        ub_loss = club_loss(h, mu_h_next, logstd_h_next, next_state_sample) * 0.1
-
         if target_obs.dim() == 4:
             # preprocess images to be in [-0.5, 0.5] range
             target_obs = utils.preprocess_obs(target_obs)
-
         rec_obs = self.decoder(h)
         rec_loss = F.mse_loss(target_obs, rec_obs)
 
@@ -462,35 +379,26 @@ class SacAeAgent(object):
         # see https://arxiv.org/pdf/1903.12436.pdf
         latent_loss = (0.5 * h.pow(2).sum(1)).mean()
 
-        loss = rec_loss + enc_loss + lb_loss + ub_loss #self.decoder_latent_lambda * latent_loss
+        loss = rec_loss + self.decoder_latent_lambda * latent_loss
         self.encoder_optimizer.zero_grad()
         self.decoder_optimizer.zero_grad()
-        self.cpc_optimizer.zero_grad()
         loss.backward()
-
-        self.encoder_optimizer.step()
+
+        self.encoder_optimizer.step()
         self.decoder_optimizer.step()
-        self.cpc_optimizer.step()
         L.log('train_ae/ae_loss', loss, step)
-        L.log('train_ae/lb_loss', lb_loss, step)
-        L.log('train_ae/ub_loss', ub_loss, step)
-        L.log('train_ae/enc_loss', enc_loss, step)
-        L.log('train_ae/dec_loss', rec_loss, step)
 
         self.decoder.log(L, step, log_freq=LOG_FREQ)
 
     def update(self, replay_buffer, L, step):
-        last_obs, last_action, last_reward, curr_obs, last_not_done, action, reward, next_obs, not_done = replay_buffer.sample()
-        #obs, action, reward, next_obs, not_done = replay_buffer.sample()
+        obs, action, reward, next_obs, not_done = replay_buffer.sample()
 
-        L.log('train/batch_reward', last_reward.mean(), step)
+        L.log('train/batch_reward', reward.mean(), step)
 
-        #self.update_critic(last_obs, last_action, last_reward, curr_obs, last_not_done, L, step)
-        self.update_critic(curr_obs, action, reward, next_obs, not_done, L, step)
+        self.update_critic(obs, action, reward, next_obs, not_done, L, step)
 
         if step % self.actor_update_freq == 0:
-            #self.update_actor_and_alpha(last_obs, L, step)
-            self.update_actor_and_alpha(curr_obs, L, step)
+            self.update_actor_and_alpha(obs, L, step)
 
         if step % self.critic_target_update_freq == 0:
             utils.soft_update_params(
@@ -505,7 +413,7 @@ class SacAeAgent(object):
             )
 
         if self.decoder is not None and step % self.decoder_update_freq == 0:
-            self.update_decoder(last_obs, last_action, last_reward, curr_obs, last_not_done, action, reward, next_obs, not_done, next_obs, L, step)
+            self.update_decoder(obs, obs, L, step)
 
     def save(self, model_dir, step):
         torch.save(
diff --git a/train.py b/train.py
index 73e9c1a..4f6cde4 100644
--- a/train.py
+++ b/train.py
@@ -26,16 +26,13 @@ def parse_args():
     parser.add_argument('--image_size', default=84, type=int)
     parser.add_argument('--action_repeat', default=1, type=int)
     parser.add_argument('--frame_stack', default=3, type=int)
-    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
-    parser.add_argument('--resource_files', type=str)
-    parser.add_argument('--total_frames', default=10000, type=int)
     # replay buffer
     parser.add_argument('--replay_buffer_capacity', default=1000000, type=int)
     # train
     parser.add_argument('--agent', default='sac_ae', type=str)
     parser.add_argument('--init_steps', default=1000, type=int)
     parser.add_argument('--num_train_steps', default=1000000, type=int)
-    parser.add_argument('--batch_size', default=512, type=int)
+    parser.add_argument('--batch_size', default=128, type=int)
     parser.add_argument('--hidden_dim', default=1024, type=int)
     # eval
     parser.add_argument('--eval_freq', default=10000, type=int)
@@ -146,10 +143,7 @@ def main():
         from_pixels=(args.encoder_type == 'pixel'),
         height=args.image_size,
         width=args.image_size,
-        frame_skip=args.action_repeat,
-        img_source=args.img_source,
-        resource_files=args.resource_files,
-        total_frames=args.total_frames
+        frame_skip=args.action_repeat
     )
     env.seed(args.seed)
@@ -218,65 +212,28 @@ def main():
                 L.log('train/episode', episode, step)
 
-            if episode_step == 0:
-                last_obs = obs
-            # sample action for data collection
-            if step < args.init_steps:
-                last_action = env.action_space.sample()
-            else:
-                with utils.eval_mode(agent):
-                    last_action = agent.sample_action(last_obs)
-
-            curr_obs, last_reward, last_done, _ = env.step(last_action)
-
-            # allow infinit bootstrap
-            last_done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(last_done)
-            episode_reward += last_reward
-
-            # sample action for data collection
-            if step < args.init_steps:
-                action = env.action_space.sample()
-            else:
-                with utils.eval_mode(agent):
-                    action = agent.sample_action(curr_obs)
-
-            next_obs, reward, done, _ = env.step(action)
-
-            # allow infinit bootstrap
-            done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
-            episode_reward += reward
-
-            replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
-
-            last_obs = curr_obs
-            last_action = action
-            last_reward = reward
-            last_done = done
-            curr_obs = next_obs
-
         # sample action for data collection
         if step < args.init_steps:
             action = env.action_space.sample()
         else:
             with utils.eval_mode(agent):
-                action = agent.sample_action(curr_obs)
+                action = agent.sample_action(obs)
 
-
         # run training update
         if step >= args.init_steps:
-            #num_updates = args.init_steps if step == args.init_steps else 1
-            num_updates = 1 if step == args.init_steps else 1
+            num_updates = args.init_steps if step == args.init_steps else 1
             for _ in range(num_updates):
                 agent.update(replay_buffer, L, step)
 
         next_obs, reward, done, _ = env.step(action)
 
         # allow infinit bootstrap
-        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
+        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
+            done
+        )
         episode_reward += reward
 
-        #replay_buffer.add(obs, action, reward, next_obs, done_bool)
-        replay_buffer.add(last_obs, last_action, last_reward, curr_obs, last_done_bool, action, reward, next_obs, done_bool)
+        replay_buffer.add(obs, action, reward, next_obs, done_bool)
 
         obs = next_obs
         episode_step += 1
diff --git a/utils.py b/utils.py
index 6eece02..067715c 100644
--- a/utils.py
+++ b/utils.py
@@ -75,26 +75,18 @@ class ReplayBuffer(object):
         # the proprioceptive obs is stored as float32, pixels obs as uint8
         obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8
 
-        self.last_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
-        self.curr_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
+        self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
         self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
-        self.last_actions = np.empty((capacity, *action_shape), dtype=np.float32)
         self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
-        self.last_rewards = np.empty((capacity, 1), dtype=np.float32)
         self.rewards = np.empty((capacity, 1), dtype=np.float32)
-        self.last_not_dones = np.empty((capacity, 1), dtype=np.float32)
         self.not_dones = np.empty((capacity, 1), dtype=np.float32)
 
         self.idx = 0
         self.last_save = 0
         self.full = False
 
-    def add(self, last_obs, last_action, last_reward, curr_obs, last_done, action, reward, next_obs, done):
-        np.copyto(self.last_obses[self.idx], last_obs)
-        np.copyto(self.last_actions[self.idx], last_action)
-        np.copyto(self.last_rewards[self.idx], last_reward)
-        np.copyto(self.curr_obses[self.idx], curr_obs)
-        np.copyto(self.last_not_dones[self.idx], not last_done)
+    def add(self, obs, action, reward, next_obs, done):
+        np.copyto(self.obses[self.idx], obs)
         np.copyto(self.actions[self.idx], action)
         np.copyto(self.rewards[self.idx], reward)
         np.copyto(self.next_obses[self.idx], next_obs)
@@ -108,31 +100,25 @@ class ReplayBuffer(object):
             0, self.capacity if self.full else self.idx, size=self.batch_size
         )
 
-        last_obses = torch.as_tensor(self.last_obses[idxs], device=self.device).float()
-        last_actions = torch.as_tensor(self.last_actions[idxs], device=self.device)
-        last_rewards = torch.as_tensor(self.last_rewards[idxs], device=self.device)
-        curr_obses = torch.as_tensor(self.curr_obses[idxs], device=self.device).float()
-        last_not_dones = torch.as_tensor(self.last_not_dones[idxs], device=self.device)
+        obses = torch.as_tensor(self.obses[idxs], device=self.device).float()
         actions = torch.as_tensor(self.actions[idxs], device=self.device)
         rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
-        next_obses = torch.as_tensor(self.next_obses[idxs], device=self.device).float()
+        next_obses = torch.as_tensor(
+            self.next_obses[idxs], device=self.device
+        ).float()
         not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
 
-        return last_obses, last_actions, last_rewards, curr_obses, last_not_dones, actions, rewards, next_obses, not_dones
+        return obses, actions, rewards, next_obses, not_dones
 
     def save(self, save_dir):
         if self.idx == self.last_save:
             return
         path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
         payload = [
-            self.last_obses[self.last_save:self.idx],
-            self.last_actions[self.last_save:self.idx],
-            self.last_rewards[self.last_save:self.idx],
-            self.curr_obses[self.last_save:self.idx],
-            self.last_not_dones[self.last_save:self.idx],
+            self.obses[self.last_save:self.idx],
+            self.next_obses[self.last_save:self.idx],
             self.actions[self.last_save:self.idx],
             self.rewards[self.last_save:self.idx],
-            self.next_obses[self.last_save:self.idx],
            self.not_dones[self.last_save:self.idx]
         ]
         self.last_save = self.idx
@@ -146,14 +132,10 @@ class ReplayBuffer(object):
             path = os.path.join(save_dir, chunk)
             payload = torch.load(path)
             assert self.idx == start
-            self.last_obses[start:end] = payload[0]
-            self.last_actions[start:end] = payload[1]
-            self.last_rewards[start:end] = payload[2]
-            self.curr_obses[start:end] = payload[3]
-            self.last_not_dones[start:end] = payload[4]
+            self.obses[start:end] = payload[0]
+            self.next_obses[start:end] = payload[1]
             self.actions[start:end] = payload[2]
             self.rewards[start:end] = payload[3]
-            self.next_obses[start:end] = payload[4]
             self.not_dones[start:end] = payload[4]
             self.idx = end