From 233ca77aa4f0d08cc252d8549de51831ea949b15 Mon Sep 17 00:00:00 2001 From: VedantDave Date: Thu, 13 Apr 2023 18:39:55 +0200 Subject: [PATCH] Completing initial model and treating memory leak --- DPI/train.py | 466 ++++++++++++++++++++++++++++----------------------- 1 file changed, 260 insertions(+), 206 deletions(-) diff --git a/DPI/train.py b/DPI/train.py index f49adc4..4845d91 100644 --- a/DPI/train.py +++ b/DPI/train.py @@ -1,15 +1,12 @@ -import numpy as np -import torch -import argparse import os -import gym -import time -import json -import dmc2gym - +import gc import copy import tqdm import wandb +import random +import argparse +import numpy as np + import utils from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample @@ -17,13 +14,12 @@ from logger import Logger from video import VideoRecorder from dmc2gym.wrappers import set_global_var +import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as T from torch.utils.tensorboard import SummaryWriter - - #from agent.baseline_agent import BaselineAgent #from agent.bisim_agent import BisimAgent #from agent.deepmdp_agent import DeepMDPAgent @@ -38,8 +34,9 @@ def parse_args(): parser.add_argument('--task_name', default='run') parser.add_argument('--image_size', default=84, type=int) parser.add_argument('--channels', default=3, type=int) - parser.add_argument('--action_repeat', default=1, type=int) + parser.add_argument('--action_repeat', default=2, type=int) parser.add_argument('--frame_stack', default=3, type=int) + parser.add_argument('--collection_interval', default=100, type=int) parser.add_argument('--resource_files', type=str) parser.add_argument('--eval_resource_files', type=str) parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none']) @@ -52,11 +49,11 @@ def parse_args(): parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc']) parser.add_argument('--init_steps', default=10000, type=int) parser.add_argument('--num_train_steps', default=10000, type=int) - parser.add_argument('--batch_size', default=20, type=int) #512 + parser.add_argument('--batch_size', default=30, type=int) #512 parser.add_argument('--state_size', default=256, type=int) parser.add_argument('--hidden_size', default=128, type=int) parser.add_argument('--history_size', default=128, type=int) - parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models') + parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models') parser.add_argument('--load_encoder', default=None, type=str) parser.add_argument('--imagine_horizon', default=15, type=str) parser.add_argument('--grad_clip_norm', type=float, default=100.0, help='Gradient clipping norm') @@ -64,15 +61,13 @@ def parse_args(): parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000 parser.add_argument('--num_eval_episodes', default=20, type=int) # value - parser.add_argument('--value_lr', default=1e-4, type=float) + parser.add_argument('--value_lr', default=8e-5, type=float) parser.add_argument('--value_beta', default=0.9, type=float) parser.add_argument('--value_tau', default=0.005, type=float) parser.add_argument('--value_target_update_freq', default=100, type=int) parser.add_argument('--td_lambda', default=0.95, type=int) - # reward - parser.add_argument('--reward_lr', default=1e-4, type=float) # actor - parser.add_argument('--actor_lr', default=1e-4, type=float) + parser.add_argument('--actor_lr', default=8e-5, type=float) parser.add_argument('--actor_beta', default=0.9, type=float) parser.add_argument('--actor_log_std_min', default=-10, type=float) parser.add_argument('--actor_log_std_max', default=2, type=float) @@ -80,7 +75,7 @@ def parse_args(): # world/encoder/decoder parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity']) parser.add_argument('--encoder_feature_dim', default=50, type=int) - parser.add_argument('--world_model_lr', default=1e-3, type=float) + parser.add_argument('--world_model_lr', default=6e-4, type=float) parser.add_argument('--past_transition_lr', default=1e-3, type=float) parser.add_argument('--encoder_lr', default=1e-3, type=float) parser.add_argument('--encoder_tau', default=0.001, type=float) @@ -100,6 +95,7 @@ def parse_args(): # misc parser.add_argument('--seed', default=1, type=int) parser.add_argument('--logging_freq', default=100, type=int) + parser.add_argument('--saving_interval', default=1000, type=int) parser.add_argument('--work_dir', default='.', type=str) parser.add_argument('--save_tb', default=False, action='store_true') parser.add_argument('--save_model', default=False, action='store_true') @@ -107,8 +103,6 @@ def parse_args(): parser.add_argument('--save_video', default=False, action='store_true') parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble']) parser.add_argument('--render', default=False, action='store_true') - parser.add_argument('--port', default=2000, type=int) - parser.add_argument('--num_likelihood_updates', default=5, type=int) args = parser.parse_args() return args @@ -119,7 +113,7 @@ def parse_args(): class DPI: - def __init__(self, args, writer): + def __init__(self, args): # wandb config #run = wandb.init(project="dpi") @@ -141,6 +135,8 @@ class DPI: # stack several consecutive frames together if self.args.encoder_type.startswith('pixel'): self.env = utils.FrameStack(self.env, k=self.args.frame_stack) + self.env = utils.ActionRepeat(self.env, self.args.action_repeat) + self.env = utils.NormalizeActions(self.env) # create replay buffer self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity, @@ -164,64 +160,64 @@ class DPI: self.obs_encoder = ObservationEncoder( obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84) state_size=self.args.state_size # 128 - ) + ).to(device) self.obs_encoder_momentum = ObservationEncoder( obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84) state_size=self.args.state_size # 128 - ) + ).to(device) self.obs_decoder = ObservationDecoder( state_size=self.args.state_size, # 128 output_shape=(self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84) - ) + ).to(device) self.transition_model = TransitionModel( state_size=self.args.state_size, # 128 hidden_size=self.args.hidden_size, # 256 action_size=self.env.action_space.shape[0], # 6 history_size=self.args.history_size, # 128 - ) + ).to(device) # Actor Model self.actor_model = Actor( state_size=self.args.state_size, # 128 hidden_size=self.args.hidden_size, # 256, action_size=self.env.action_space.shape[0], # 6 - ) + ).to(device) # Value Models self.value_model = ValueModel( state_size=self.args.state_size, # 128 hidden_size=self.args.hidden_size, # 256 - ) + ).to(device) self.target_value_model = ValueModel( state_size=self.args.state_size, # 128 hidden_size=self.args.hidden_size, # 256 - ) + ).to(device) self.reward_model = RewardModel( state_size=self.args.state_size, # 128 hidden_size=self.args.hidden_size, # 256 - ) + ).to(device) # Contrastive Models self.prjoection_head = ProjectionHead( state_size=self.args.state_size, # 128 action_size=self.env.action_space.shape[0], # 6 hidden_size=self.args.hidden_size, # 256 - ) + ).to(device) self.prjoection_head_momentum = ProjectionHead( state_size=self.args.state_size, # 128 action_size=self.env.action_space.shape[0], # 6 hidden_size=self.args.hidden_size, # 256 - ) + ).to(device) self.contrastive_head = ContrastiveHead( hidden_size=self.args.hidden_size, # 256 - ) + ).to(device) # model parameters @@ -237,7 +233,7 @@ class DPI: self.past_transition_opt = torch.optim.Adam(self.past_transition_parameters, self.args.past_transition_lr) # Create Modules - self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.value_model, self.transition_model, self.prjoection_head] + self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.reward_model, self.transition_model, self.prjoection_head] self.value_modules = [self.value_model] self.actor_modules = [self.actor_model] @@ -249,21 +245,27 @@ class DPI: self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt'))) self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt'))) - def collect_sequences(self, episodes): + def collect_sequences(self, episodes, random=True, actor_model=None, encoder_model=None): obs = self.env.reset() done = False + all_rews = [] #video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files) for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'): if args.save_video: self.env.video.init(enabled=True) + epi_reward = 0 for i in range(self.args.episode_length): + if random: + action = self.env.action_space.sample() + else: + with torch.no_grad(): + obs_torch = torch.unsqueeze(torch.tensor(obs).float(),0).to(device) + state = self.obs_encoder(obs_torch)["distribution"].sample() + action = self.actor_model(state).cpu().detach().numpy().squeeze() - action = self.env.action_space.sample() - next_obs, rew, done, _ = self.env.step(action) - self.data_buffer.add(obs, action, next_obs, rew, episode_count+1, done) if args.save_video: @@ -274,184 +276,222 @@ class DPI: done=False else: obs = next_obs + epi_reward += rew + all_rews.append(epi_reward) if args.save_video: self.env.video.save('noisy/%d.mp4' % episode_count) print("Collected {} random episodes".format(episode_count+1)) + return all_rews - def train(self): - # collect experience - self.collect_sequences(self.args.batch_size) - - # Group observations and next_observations by steps from past to present - last_observations = torch.tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()[:self.args.episode_length-1] - current_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()[:self.args.episode_length-1] - next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()[1:] - actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[:self.args.episode_length-1] - next_actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[1:] - rewards = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"rewards",obs=False)).float()[1:] - - # Preprocessing - last_observations = preprocess_obs(last_observations) - current_observations = preprocess_obs(current_observations) - next_observations = preprocess_obs(next_observations) - - # Initialize transition model states - self.transition_model.init_states(self.args.batch_size, device="cpu") # (N,128) - self.history = self.transition_model.prev_history # (N,128) - - # Train encoder - step = 0 - total_steps = 10000 - metrics = {} + def train(self, step, total_steps): + counter = 0 while step < total_steps: - for i in range(self.args.episode_length-1): - if i > 0: - # Encode observations and next_observations - self.last_states_dict = self.get_features(last_observations[i]) - self.current_states_dict = self.get_features(current_observations[i]) - self.next_states_dict = self.get_features(next_observations[i], momentum=True) - self.action = actions[i] # (N,6) - self.next_action = next_actions[i] # (N,6) - history = self.transition_model.prev_history - - # Encode negative observations - idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch - random_time_index = torch.randint(0, self.args.episode_length-2, (1,)).item() # random time index - negative_current_observations = current_observations[random_time_index][idx] - self.negative_current_states_dict = self.obs_encoder(negative_current_observations) + + # collect experience + if step !=0: + encoder = self.obs_encoder + actor = self.actor_model + #all_rews = self.collect_sequences(self.args.batch_size, random=True) + all_rews = self.collect_sequences(self.args.batch_size, random=False, actor_model=actor, encoder_model=encoder) + else: + all_rews = self.collect_sequences(self.args.batch_size, random=True) - # Predict current state from past state with transition model - last_states_sample = self.last_states_dict["sample"] - predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history) - self.history = predicted_current_state_dict["history"] + # Group by steps and sample random batch + random_indices = self.data_buffer.sample_random_idx(self.args.batch_size * ((step//self.args.collection_interval)+1)) # random indices for batch + #random_indices = np.arange(self.args.batch_size * ((step//self.args.collection_interval)),self.args.batch_size * ((step//self.args.collection_interval)+1)) + last_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"observations", "cpu", random_indices=random_indices) + current_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"next_observations", device="cpu", random_indices=random_indices) + next_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"next_observations", device="cpu", offset=1, random_indices=random_indices) + actions = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"actions", device=device, is_obs=False, random_indices=random_indices) + next_actions = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"actions", device=device, is_obs=False, offset=1, random_indices=random_indices) + rewards = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"rewards", device=device, is_obs=False, offset=1, random_indices=random_indices) - # Calculate upper bound loss - likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict, - self.current_states_dict, - self.negative_current_states_dict, - predicted_current_state_dict - ) - #likeli_loss = torch.tensor(likeli_loss.numpy(),dtype=torch.float32, requires_grad=True) - #ikeli_loss = likeli_loss.mean() - - # Calculate encoder loss - encoder_loss = self._past_encoder_loss(self.current_states_dict, - predicted_current_state_dict) + # Preprocessing + last_observations = preprocess_obs(last_observations).to(device) + current_observations = preprocess_obs(current_observations).to(device) + next_observations = preprocess_obs(next_observations).to(device) - #total_ub_loss += ub_loss - #total_encoder_loss += encoder_loss - - # contrastive projection - vec_anchor = predicted_current_state_dict["sample"] - vec_positive = self.next_states_dict["sample"].detach() - z_anchor = self.prjoection_head(vec_anchor, self.action) - z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach() + # Initialize transition model states + self.transition_model.init_states(self.args.batch_size, device) # (N,128) + self.history = self.transition_model.prev_history # (N,128) - # contrastive loss - logits = self.contrastive_head(z_anchor, z_positive) - labels = labels = torch.arange(logits.shape[0]).long() - lb_loss = F.cross_entropy(logits, labels) - - # behaviour learning - with FreezeParameters(self.world_model_modules): - imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i) - imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(), - self.next_action, self.history.detach(), - imagine_horizon) - - # decoder loss - horizon = np.minimum(50-i, imagine_horizon) - obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon]) - decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:])) - - # reward loss - reward_dist = self.reward_model(self.current_states_dict["sample"]) - reward_loss = -torch.mean(reward_dist.log_prob(rewards[:-1])) - - # update models - world_model_loss = encoder_loss + ub_loss + lb_loss + decoder_loss * 1e-2 - self.world_model_opt.zero_grad() - world_model_loss.backward() - nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm) - self.world_model_opt.step() - - # actor loss - with FreezeParameters(self.world_model_modules + self.value_modules): - imag_rew_dist = self.reward_model(imagined_rollout["sample"]) - target_imag_val_dist = self.target_value_model(imagined_rollout["sample"]) - - imag_rews = imag_rew_dist.mean - target_imag_vals = target_imag_val_dist.mean - - discounts = self.args.discount * torch.ones_like(imag_rews).detach() - - self.target_returns = self._compute_lambda_return(imag_rews[:-1], - target_imag_vals[:-1], - discounts[:-1] , - self.args.td_lambda, - target_imag_vals[-1]) - - discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0) - self.discounts = torch.cumprod(discounts, 0).detach() - actor_loss = -torch.mean(self.discounts * self.target_returns) - - # update actor - self.actor_opt.zero_grad() - actor_loss.backward() - nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm) - self.actor_opt.step() - - # value loss - with torch.no_grad(): - value_feat = imagined_rollout["sample"][:-1].detach() - value_targ = self.target_returns.detach() - - value_dist = self.value_model(value_feat) - value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1)) - - # update value - self.value_opt.zero_grad() - value_loss.backward() - nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm) - self.value_opt.step() - - # update target value - if step % self.args.value_target_update_freq == 0: - self.target_value_model = copy.deepcopy(self.value_model) - - # update momentum encoder - soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau) - - # update momentum projection head - soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau) - - step += 1 - - if step % self.args.logging_freq: - writer.add_scalar('Main Loss/World Loss', world_model_loss, step) - writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss, step) - writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, step) - writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss, step) - writer.add_scalar('Actor Critic Loss/Value Loss', value_loss, step) - writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss, step) - writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss, step) - writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss, step) + # Train encoder + if step == 0: + step += 1 + for _ in range(self.args.collection_interval // self.args.episode_length+1): + counter += 1 + for i in range(self.args.episode_length-1): + if i > 0: + # Encode observations and next_observations + self.last_states_dict = self.get_features(last_observations[i]) + self.current_states_dict = self.get_features(current_observations[i]) + self.next_states_dict = self.get_features(next_observations[i], momentum=True) + self.action = actions[i] # (N,6) + self.next_action = next_actions[i] # (N,6) + history = self.transition_model.prev_history - """ - if step % self.args.logging_freq: - metrics['Upper Bound Loss'] = ub_loss.item() - metrics['Encoder Loss'] = encoder_loss.item() - metrics['Decoder Loss'] = decoder_loss.item() - metrics["Lower Bound Loss"] = lb_loss.item() - metrics["World Model Loss"] = world_model_loss.item() - wandb.log(metrics) - """ + # Encode negative observations + idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch + random_time_index = torch.randint(0, self.args.episode_length-2, (1,)).item() # random time index + negative_current_observations = current_observations[random_time_index][idx] + self.negative_current_states_dict = self.obs_encoder(negative_current_observations) - if step>total_steps: - print("Training finished") - break + # Predict current state from past state with transition model + last_states_sample = self.last_states_dict["sample"] + predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history) + self.history = predicted_current_state_dict["history"] + # Calculate upper bound loss + likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict, + self.current_states_dict, + self.negative_current_states_dict, + predicted_current_state_dict + ) + + # Calculate encoder loss + encoder_loss = self._past_encoder_loss(self.current_states_dict, + predicted_current_state_dict) + # contrastive projection + vec_anchor = predicted_current_state_dict["sample"] + vec_positive = self.next_states_dict["sample"].detach() + z_anchor = self.prjoection_head(vec_anchor, self.action) + z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach() + + # contrastive loss + logits = self.contrastive_head(z_anchor, z_positive) + labels = torch.arange(logits.shape[0]).long().to(device) + lb_loss = F.cross_entropy(logits, labels) + + # behaviour learning + with FreezeParameters(self.world_model_modules): + imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i) + imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(), + self.next_action, self.history.detach(), + imagine_horizon) + + # decoder loss + horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i) + obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon]) + decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:])) + + # reward loss + reward_dist = self.reward_model(self.current_states_dict["sample"]) + reward_loss = -torch.mean(reward_dist.log_prob(rewards[:-1])) + + # update models + world_model_loss = encoder_loss + 100 * ub_loss + lb_loss + reward_loss + decoder_loss * 1e-2 + self.world_model_opt.zero_grad() + world_model_loss.backward() + nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm) + self.world_model_opt.step() + + # update momentum encoder + soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau) + + # update momentum projection head + soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau) + + # actor loss + with FreezeParameters(self.world_model_modules + self.value_modules): + imag_rew_dist = self.reward_model(imagined_rollout["sample"]) + target_imag_val_dist = self.target_value_model(imagined_rollout["sample"]) + + imag_rews = imag_rew_dist.mean + target_imag_vals = target_imag_val_dist.mean + + discounts = self.args.discount * torch.ones_like(imag_rews).detach() + + self.target_returns = self._compute_lambda_return(imag_rews[:-1], + target_imag_vals[:-1], + discounts[:-1] , + self.args.td_lambda, + target_imag_vals[-1]) + + discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0) + self.discounts = torch.cumprod(discounts, 0).detach() + actor_loss = -torch.mean(self.discounts * self.target_returns) + + # update actor + self.actor_opt.zero_grad() + actor_loss.backward() + nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm) + self.actor_opt.step() + + # value loss + with torch.no_grad(): + value_feat = imagined_rollout["sample"][:-1].detach() + value_targ = self.target_returns.detach() + + value_dist = self.value_model(value_feat) + value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1)) + + # update value + self.value_opt.zero_grad() + value_loss.backward() + nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm) + self.value_opt.step() + + # update target value + if step % self.args.value_target_update_freq == 0: + self.target_value_model = copy.deepcopy(self.value_model) + + # counter for reward + count = np.arange((counter-1) * (self.args.batch_size), (counter) * (self.args.batch_size)) + + + if step % self.args.logging_freq: + writer.add_scalar('World Loss/World Loss', world_model_loss.detach().item(), step) + writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss.detach().item(), step) + writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, step) + writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), step) + writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step) + writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss.detach().item(), step) + writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step) + writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), step) + + step += 1 + if step>total_steps: + print("Training finished") + break + + # save model + if step % self.args.saving_interval == 0: + path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth" + self.save_models(path) + + #torch.cuda.empty_cache() # memory leak issues + + for j in range(len(all_rews)): + writer.add_scalar('Rewards/Rewards', all_rews[j], count[j]) + + + def evaluate(self, env, eval_episodes, render=False): + + episode_rew = np.zeros((eval_episodes)) + + video_images = [[] for _ in range(eval_episodes)] + + for i in range(eval_episodes): + obs = env.reset() + done = False + prev_state = self.rssm.init_state(1, self.device) + prev_action = torch.zeros(1, self.action_size).to(self.device) + + while not done: + with torch.no_grad(): + posterior, action = self.act_with_world_model(obs, prev_state, prev_action) + action = action[0].cpu().numpy() + next_obs, rew, done, _ = env.step(action) + prev_state = posterior + prev_action = torch.tensor(action, dtype=torch.float32).to(self.device).unsqueeze(0) + + episode_rew[i] += rew + + if render: + video_images[i].append(obs['image'].transpose(1,2,0).copy()) + obs = next_obs + return episode_rew, np.array(video_images[:self.args.max_videos_to_save]) def _upper_bound_minimization(self, last_states, current_states, negative_current_states, predicted_current_states): club_sample = CLUBSample(last_states, @@ -469,8 +509,6 @@ class DPI: # predicted current state distribution predicted_curr_states_dist = predicted_curr_states_dict["distribution"] - - # KL divergence loss loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean() @@ -501,11 +539,27 @@ class DPI: returns = torch.flip(torch.stack(rets), [0]) return returns + + def save_models(self, save_path): + torch.save( + {'rssm' : self.transition_model.state_dict(), + 'actor': self.actor_model.state_dict(), + 'reward_model': self.reward_model.state_dict(), + 'obs_encoder': self.obs_encoder.state_dict(), + 'obs_decoder': self.obs_decoder.state_dict(), + 'actor_optimizer': self.actor_opt.state_dict(), + 'value_optimizer': self.value_opt.state_dict(), + 'world_model_optimizer': self.world_model_opt.state_dict(),}, save_path) if __name__ == '__main__': args = parse_args() writer = SummaryWriter() + - dpi = DPI(args, writer) - dpi.train() \ No newline at end of file + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + + step = 0 + total_steps = 10000 + dpi = DPI(args) + dpi.train(step,total_steps) \ No newline at end of file