import numpy as np
import torch
import argparse
import os
import gym
import time
import json
import dmc2gym
import copy
import tqdm
import wandb

import utils
from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
from logger import Logger
from video import VideoRecorder
from dmc2gym.wrappers import set_global_var

import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter

#from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent
#from agents.navigation.carla_env import CarlaEnv


def parse_args():
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--version', default=1, type=int)
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--channels', default=3, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)  # 10000
    parser.add_argument('--high_noise', action='store_true')
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=50000, type=int)  # 50000
    parser.add_argument('--episode_length', default=51, type=int)
    # train
    parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
    parser.add_argument('--init_steps', default=10000, type=int)
    parser.add_argument('--num_train_steps', default=10000, type=int)
    parser.add_argument('--batch_size', default=20, type=int)  # 512
    parser.add_argument('--state_size', default=256, type=int)
    parser.add_argument('--hidden_size', default=128, type=int)
    parser.add_argument('--history_size', default=128, type=int)
    parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models')
    parser.add_argument('--load_encoder', default=None, type=str)
    parser.add_argument('--imagine_horizon', default=15, type=int)
    parser.add_argument('--grad_clip_norm', type=float, default=100.0, help='Gradient clipping norm')
    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    # value
    parser.add_argument('--value_lr', default=1e-4, type=float)
    parser.add_argument('--value_beta', default=0.9, type=float)
    parser.add_argument('--value_tau', default=0.005, type=float)
    parser.add_argument('--value_target_update_freq', default=100, type=int)
    parser.add_argument('--td_lambda', default=0.95, type=float)
    # reward
    parser.add_argument('--reward_lr', default=1e-4, type=float)
    # actor
    parser.add_argument('--actor_lr', default=1e-4, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # world/encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--world_model_lr', default=1e-3, type=float)
    parser.add_argument('--past_transition_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.001, type=float)
    parser.add_argument('--encoder_stride', default=1, type=int)
    parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    parser.add_argument('--aug', action='store_true')
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--logging_freq', default=100, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--render', default=False, action='store_true')
    parser.add_argument('--port', default=2000, type=int)
    parser.add_argument('--num_likelihood_updates', default=5, type=int)
    args = parser.parse_args()
    return args


class DPI:
    def __init__(self, args, writer):
        # wandb config
        #run = wandb.init(project="dpi")

        self.args = args
        self.writer = writer

        # set environment noise
        set_global_var(self.args.high_noise)

        # environment setup
        self.env = make_env(self.args)
        #self.args.seed = np.random.randint(0, 1000)
        self.env.seed(self.args.seed)

        # noiseless environment setup
        self.args.version = 2          # env_id changes to v2
        self.args.img_source = None    # no image noise
        self.args.resource_files = None

        # stack several consecutive frames together
        if self.args.encoder_type.startswith('pixel'):
            self.env = utils.FrameStack(self.env, k=self.args.frame_stack)

        # create replay buffer
        self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity,
                                        obs_shape=(self.args.frame_stack * self.args.channels, self.args.image_size, self.args.image_size),
                                        action_size=self.env.action_space.shape[0],
                                        seq_len=self.args.episode_length,
                                        batch_size=args.batch_size,
                                        args=self.args)

        # create work directory
        utils.make_dir(self.args.work_dir)
        self.video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video'))
        self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
        self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))

        # create models
        self.build_models(use_saved=False, saved_model_dir=self.model_dir)
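
    # Overview of the modules created in build_models() below (a reading aid only;
    # the layer definitions live in models.py, and sizes refer to the default CLI arguments):
    #   ObservationEncoder / _momentum   : stacked frames (frame_stack*channels, 84, 84) -> latent state of size state_size
    #   ObservationDecoder               : latent state -> single reconstructed frame (channels, 84, 84)
    #   TransitionModel                  : (previous state, action, history) -> predicted current state
    #   Actor / ValueModel / RewardModel : act on, evaluate, and predict reward from the latent state
    #   ProjectionHead (+momentum) and ContrastiveHead : produce the contrastive lower-bound loss used in train()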

    def build_models(self, use_saved, saved_model_dir=None):
        # World Models
        self.obs_encoder = ObservationEncoder(
            obs_shape=(self.args.frame_stack * self.args.channels, self.args.image_size, self.args.image_size),  # (9,84,84)
            state_size=self.args.state_size  # 256
        )
        self.obs_encoder_momentum = ObservationEncoder(
            obs_shape=(self.args.frame_stack * self.args.channels, self.args.image_size, self.args.image_size),  # (9,84,84)
            state_size=self.args.state_size  # 256
        )
        self.obs_decoder = ObservationDecoder(
            state_size=self.args.state_size,  # 256
            output_shape=(self.args.channels, self.args.image_size, self.args.image_size)  # (3,84,84)
        )
        self.transition_model = TransitionModel(
            state_size=self.args.state_size,  # 256
            hidden_size=self.args.hidden_size,  # 128
            action_size=self.env.action_space.shape[0],  # 6
            history_size=self.args.history_size,  # 128
        )

        # Actor Model
        self.actor_model = Actor(
            state_size=self.args.state_size,  # 256
            hidden_size=self.args.hidden_size,  # 128
            action_size=self.env.action_space.shape[0],  # 6
        )

        # Value Models
        self.value_model = ValueModel(
            state_size=self.args.state_size,  # 256
            hidden_size=self.args.hidden_size  # 128
        )
        self.target_value_model = ValueModel(
            state_size=self.args.state_size,  # 256
            hidden_size=self.args.hidden_size  # 128
        )
        self.reward_model = RewardModel(
            state_size=self.args.state_size,  # 256
            hidden_size=self.args.hidden_size  # 128
        )

        # Contrastive Models
        self.projection_head = ProjectionHead(
            state_size=self.args.state_size,  # 256
            action_size=self.env.action_space.shape[0],  # 6
            hidden_size=self.args.hidden_size,  # 128
        )
        self.projection_head_momentum = ProjectionHead(
            state_size=self.args.state_size,  # 256
            action_size=self.env.action_space.shape[0],  # 6
            hidden_size=self.args.hidden_size,  # 128
        )
        self.contrastive_head = ContrastiveHead(
            hidden_size=self.args.hidden_size,  # 128
        )

        # model parameters
        self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_decoder.parameters()) + \
                                      list(self.value_model.parameters()) + list(self.transition_model.parameters()) + \
                                      list(self.projection_head.parameters())
        self.past_transition_parameters = self.transition_model.parameters()

        # optimizers
        self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr)
        self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr)
        self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr)
        self.past_transition_opt = torch.optim.Adam(self.past_transition_parameters, self.args.past_transition_lr)

        # Create Modules
        self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.value_model, self.transition_model, self.projection_head]
        self.value_modules = [self.value_model]
        self.actor_modules = [self.actor_model]

        if use_saved:
            self._use_saved_models(saved_model_dir)

    def _use_saved_models(self, saved_model_dir):
        self.obs_encoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_encoder.pt')))
        self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
        self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))

    def collect_sequences(self, episodes):
        obs = self.env.reset()
        done = False
        #video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
        for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
            if self.args.save_video:
                self.env.video.init(enabled=True)
            for i in range(self.args.episode_length):
                action = self.env.action_space.sample()
                next_obs, rew, done, _ = self.env.step(action)
                self.data_buffer.add(obs, action, next_obs, rew, episode_count + 1, done)
                if self.args.save_video:
                    self.env.video.record(self.env)
                if done or i == self.args.episode_length - 1:
                    obs = self.env.reset()
                    done = False
                else:
                    obs = next_obs
            if self.args.save_video:
                self.env.video.save('noisy/%d.mp4' % episode_count)
        print("Collected {} random episodes".format(episode_count + 1))

    def train(self):
        # collect experience
        self.collect_sequences(self.args.batch_size)

        # Group observations and next_observations by steps from past to present
        last_observations = torch.tensor(self.data_buffer.group_steps(self.data_buffer, "observations")).float()[:self.args.episode_length - 1]
        current_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer, "next_observations")).float()[:self.args.episode_length - 1]
        next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer, "next_observations")).float()[1:]
        actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer, "actions", obs=False)).float()[:self.args.episode_length - 1]
        next_actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer, "actions", obs=False)).float()[1:]
        rewards = torch.Tensor(self.data_buffer.group_steps(self.data_buffer, "rewards", obs=False)).float()[1:]

        # Preprocessing
        last_observations = preprocess_obs(last_observations)
        current_observations = preprocess_obs(current_observations)
        next_observations = preprocess_obs(next_observations)

        # Initialize transition model states
        self.transition_model.init_states(self.args.batch_size, device="cpu")  # (N,128)
        self.history = self.transition_model.prev_history  # (N,128)

        # Train encoder
        step = 0
        total_steps = 10000
        metrics = {}
        while step < total_steps:
            for i in range(self.args.episode_length - 1):
                if i > 0:
                    # Encode observations and next_observations
                    self.last_states_dict = self.get_features(last_observations[i])
                    self.current_states_dict = self.get_features(current_observations[i])
                    self.next_states_dict = self.get_features(next_observations[i], momentum=True)
                    self.action = actions[i]  # (N,6)
                    self.next_action = next_actions[i]  # (N,6)
                    history = self.transition_model.prev_history

                    # Encode negative observations
                    idx = torch.randperm(current_observations[i].shape[0])  # random permutation on batch
                    random_time_index = torch.randint(0, self.args.episode_length - 2, (1,)).item()  # random time index
                    negative_current_observations = current_observations[random_time_index][idx]
                    self.negative_current_states_dict = self.obs_encoder(negative_current_observations)

                    # Predict current state from past state with transition model
                    last_states_sample = self.last_states_dict["sample"]
                    predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
                    self.history = predicted_current_state_dict["history"]

                    # Calculate upper bound loss
                    likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict,
                                                                          self.current_states_dict,
                                                                          self.negative_current_states_dict,
                                                                          predicted_current_state_dict)
                    #likeli_loss = torch.tensor(likeli_loss.numpy(), dtype=torch.float32, requires_grad=True)
                    #likeli_loss = likeli_loss.mean()

                    # Calculate encoder loss
                    encoder_loss = self._past_encoder_loss(self.current_states_dict, predicted_current_state_dict)
                    #total_ub_loss += ub_loss
                    #total_encoder_loss += encoder_loss

                    # contrastive projection
                    vec_anchor = predicted_current_state_dict["sample"]
                    vec_positive = self.next_states_dict["sample"].detach()
                    z_anchor = self.projection_head(vec_anchor, self.action)
                    z_positive = self.projection_head_momentum(vec_positive, next_actions[i]).detach()

                    # contrastive loss
                    logits = self.contrastive_head(z_anchor, z_positive)
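                    # The lower-bound loss below is the standard InfoNCE objective. This
                    # assumes ContrastiveHead returns an (N, N) matrix of similarity scores
                    # between every anchor and every positive in the batch, so the correct
                    # "class" for row i is column i, hence labels = arange(N).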
                    labels = torch.arange(logits.shape[0]).long()
                    lb_loss = F.cross_entropy(logits, labels)

                    # behaviour learning
                    with FreezeParameters(self.world_model_modules):
                        imagine_horizon = self.args.imagine_horizon  #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
                        imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(),
                                                                                 self.next_action,
                                                                                 self.history.detach(),
                                                                                 imagine_horizon)

                    # decoder loss
                    horizon = np.minimum(50 - i, imagine_horizon)
                    obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon])
                    decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i + horizon][:, :, :3, :, :]))

                    # reward loss
                    reward_dist = self.reward_model(self.current_states_dict["sample"])
                    reward_loss = -torch.mean(reward_dist.log_prob(rewards[:-1]))

                    # update models
                    world_model_loss = encoder_loss + ub_loss + lb_loss + decoder_loss * 1e-2
                    self.world_model_opt.zero_grad()
                    world_model_loss.backward()
                    nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
                    self.world_model_opt.step()

                    # actor loss
                    with FreezeParameters(self.world_model_modules + self.value_modules):
                        imag_rew_dist = self.reward_model(imagined_rollout["sample"])
                        target_imag_val_dist = self.target_value_model(imagined_rollout["sample"])

                        imag_rews = imag_rew_dist.mean
                        target_imag_vals = target_imag_val_dist.mean

                        discounts = self.args.discount * torch.ones_like(imag_rews).detach()

                        self.target_returns = self._compute_lambda_return(imag_rews[:-1],
                                                                          target_imag_vals[:-1],
                                                                          discounts[:-1],
                                                                          self.args.td_lambda,
                                                                          target_imag_vals[-1])

                        discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
                        self.discounts = torch.cumprod(discounts, 0).detach()
                        actor_loss = -torch.mean(self.discounts * self.target_returns)

                    # update actor
                    self.actor_opt.zero_grad()
                    actor_loss.backward()
                    nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
                    self.actor_opt.step()

                    # value loss
                    with torch.no_grad():
                        value_feat = imagined_rollout["sample"][:-1].detach()
                        value_targ = self.target_returns.detach()
                    value_dist = self.value_model(value_feat)
                    value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))

                    # update value
                    self.value_opt.zero_grad()
                    value_loss.backward()
                    nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
                    self.value_opt.step()

                    # update target value
                    if step % self.args.value_target_update_freq == 0:
                        self.target_value_model = copy.deepcopy(self.value_model)

                    # update momentum encoder
                    soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)

                    # update momentum projection head
                    soft_update_params(self.projection_head, self.projection_head_momentum, self.args.encoder_tau)

                    step += 1

                    if step % self.args.logging_freq == 0:
                        self.writer.add_scalar('Main Loss/World Loss', world_model_loss, step)
                        self.writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss, step)
                        self.writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, step)
                        self.writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss, step)
                        self.writer.add_scalar('Actor Critic Loss/Value Loss', value_loss, step)
                        self.writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss, step)
                        self.writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss, step)
                        self.writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss, step)
Loss"] = world_model_loss.item() wandb.log(metrics) """ if step>total_steps: print("Training finished") break def _upper_bound_minimization(self, last_states, current_states, negative_current_states, predicted_current_states): club_sample = CLUBSample(last_states, current_states, negative_current_states, predicted_current_states) likelihood_loss = club_sample.learning_loss() club_loss = club_sample() return likelihood_loss, club_loss def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict): # current state distribution curr_states_dist = curr_states_dict["distribution"] # predicted current state distribution predicted_curr_states_dist = predicted_curr_states_dict["distribution"] # KL divergence loss loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean() return loss def get_features(self, x, momentum=False): if self.args.aug: x = T.RandomCrop((80, 80))(x) # (None,80,80,4) x = T.functional.pad(x, (4, 4, 4, 4), "symmetric") # (None,88,88,4) x = T.RandomCrop((84, 84))(x) # (None,84,84,4) with torch.no_grad(): if momentum: x = self.obs_encoder_momentum(x) else: x = self.obs_encoder(x) return x def _compute_lambda_return(self, rewards, values, discounts, td_lam, last_value): next_values = torch.cat([values[1:], last_value.unsqueeze(0)],0) targets = rewards + discounts * next_values * (1-td_lam) rets =[] last_rew = last_value for t in range(rewards.shape[0]-1, -1, -1): last_rew = targets[t] + discounts[t] * td_lam *(last_rew) rets.append(last_rew) returns = torch.flip(torch.stack(rets), [0]) return returns if __name__ == '__main__': args = parse_args() writer = SummaryWriter() dpi = DPI(args, writer) dpi.train()