Completing initial model and treating memory leak

Author: Vedant Dave
Date:   2023-04-13 18:39:55 +02:00
Parent: 3e9d8f7a9c
Commit: 233ca77aa4


@@ -1,15 +1,12 @@
-import numpy as np
-import torch
-import argparse
import os
-import gym
+import gc
-import time
-import json
-import dmc2gym
import copy
import tqdm
import wandb
-import random
+import argparse
+import numpy as np
import utils
from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
@@ -17,13 +14,12 @@ from logger import Logger
from video import VideoRecorder
from dmc2gym.wrappers import set_global_var
-import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter
#from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent
@@ -38,8 +34,9 @@ def parse_args():
parser.add_argument('--task_name', default='run')
parser.add_argument('--image_size', default=84, type=int)
parser.add_argument('--channels', default=3, type=int)
-parser.add_argument('--action_repeat', default=1, type=int)
+parser.add_argument('--action_repeat', default=2, type=int)
parser.add_argument('--frame_stack', default=3, type=int)
+parser.add_argument('--collection_interval', default=100, type=int)
parser.add_argument('--resource_files', type=str)
parser.add_argument('--eval_resource_files', type=str)
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
@@ -52,11 +49,11 @@ def parse_args():
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
parser.add_argument('--init_steps', default=10000, type=int)
parser.add_argument('--num_train_steps', default=10000, type=int)
-parser.add_argument('--batch_size', default=20, type=int) #512
+parser.add_argument('--batch_size', default=30, type=int) #512
parser.add_argument('--state_size', default=256, type=int)
parser.add_argument('--hidden_size', default=128, type=int)
parser.add_argument('--history_size', default=128, type=int)
-parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models')
+parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models')
parser.add_argument('--load_encoder', default=None, type=str)
parser.add_argument('--imagine_horizon', default=15, type=str)
parser.add_argument('--grad_clip_norm', type=float, default=100.0, help='Gradient clipping norm')
@@ -64,15 +61,13 @@ def parse_args():
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
parser.add_argument('--num_eval_episodes', default=20, type=int)
# value
-parser.add_argument('--value_lr', default=1e-4, type=float)
+parser.add_argument('--value_lr', default=8e-5, type=float)
parser.add_argument('--value_beta', default=0.9, type=float)
parser.add_argument('--value_tau', default=0.005, type=float)
parser.add_argument('--value_target_update_freq', default=100, type=int)
parser.add_argument('--td_lambda', default=0.95, type=int)
-# reward
-parser.add_argument('--reward_lr', default=1e-4, type=float)
# actor
-parser.add_argument('--actor_lr', default=1e-4, type=float)
+parser.add_argument('--actor_lr', default=8e-5, type=float)
parser.add_argument('--actor_beta', default=0.9, type=float)
parser.add_argument('--actor_log_std_min', default=-10, type=float)
parser.add_argument('--actor_log_std_max', default=2, type=float)
@@ -80,7 +75,7 @@ def parse_args():
# world/encoder/decoder
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
parser.add_argument('--encoder_feature_dim', default=50, type=int)
-parser.add_argument('--world_model_lr', default=1e-3, type=float)
+parser.add_argument('--world_model_lr', default=6e-4, type=float)
parser.add_argument('--past_transition_lr', default=1e-3, type=float)
parser.add_argument('--encoder_lr', default=1e-3, type=float)
parser.add_argument('--encoder_tau', default=0.001, type=float)
@@ -100,6 +95,7 @@ def parse_args():
# misc
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--logging_freq', default=100, type=int)
+parser.add_argument('--saving_interval', default=1000, type=int)
parser.add_argument('--work_dir', default='.', type=str)
parser.add_argument('--save_tb', default=False, action='store_true')
parser.add_argument('--save_model', default=False, action='store_true')
@@ -107,8 +103,6 @@ def parse_args():
parser.add_argument('--save_video', default=False, action='store_true')
parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble'])
parser.add_argument('--render', default=False, action='store_true')
-parser.add_argument('--port', default=2000, type=int)
-parser.add_argument('--num_likelihood_updates', default=5, type=int)
args = parser.parse_args()
return args
@@ -119,7 +113,7 @@ def parse_args():
class DPI:
-def __init__(self, args, writer):
+def __init__(self, args):
# wandb config
#run = wandb.init(project="dpi")
@@ -141,6 +135,8 @@ class DPI:
# stack several consecutive frames together
if self.args.encoder_type.startswith('pixel'):
self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
+self.env = utils.ActionRepeat(self.env, self.args.action_repeat)
+self.env = utils.NormalizeActions(self.env)
# create replay buffer
self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity,
@@ -164,64 +160,64 @@ class DPI:
self.obs_encoder = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84)
state_size=self.args.state_size # 128
-)
+).to(device)
self.obs_encoder_momentum = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84)
state_size=self.args.state_size # 128
-)
+).to(device)
self.obs_decoder = ObservationDecoder(
state_size=self.args.state_size, # 128
output_shape=(self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84)
-)
+).to(device)
self.transition_model = TransitionModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
action_size=self.env.action_space.shape[0], # 6
history_size=self.args.history_size, # 128
-)
+).to(device)
# Actor Model
self.actor_model = Actor(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256,
action_size=self.env.action_space.shape[0], # 6
-)
+).to(device)
# Value Models
self.value_model = ValueModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
-)
+).to(device)
self.target_value_model = ValueModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
-)
+).to(device)
self.reward_model = RewardModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
-)
+).to(device)
# Contrastive Models
self.prjoection_head = ProjectionHead(
state_size=self.args.state_size, # 128
action_size=self.env.action_space.shape[0], # 6
hidden_size=self.args.hidden_size, # 256
-)
+).to(device)
self.prjoection_head_momentum = ProjectionHead(
state_size=self.args.state_size, # 128
action_size=self.env.action_space.shape[0], # 6
hidden_size=self.args.hidden_size, # 256
-)
+).to(device)
self.contrastive_head = ContrastiveHead(
hidden_size=self.args.hidden_size, # 256
-)
+).to(device)
# model parameters
@@ -237,7 +233,7 @@ class DPI:
self.past_transition_opt = torch.optim.Adam(self.past_transition_parameters, self.args.past_transition_lr)
# Create Modules
-self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.value_model, self.transition_model, self.prjoection_head]
+self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.reward_model, self.transition_model, self.prjoection_head]
self.value_modules = [self.value_model]
self.actor_modules = [self.actor_model]
@@ -249,21 +245,27 @@ class DPI:
self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))
-def collect_sequences(self, episodes):
+def collect_sequences(self, episodes, random=True, actor_model=None, encoder_model=None):
obs = self.env.reset()
done = False
+all_rews = []
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
if args.save_video:
self.env.video.init(enabled=True)
+epi_reward = 0
for i in range(self.args.episode_length):
-action = self.env.action_space.sample()
+if random:
+action = self.env.action_space.sample()
+else:
+with torch.no_grad():
+obs_torch = torch.unsqueeze(torch.tensor(obs).float(),0).to(device)
+state = self.obs_encoder(obs_torch)["distribution"].sample()
+action = self.actor_model(state).cpu().detach().numpy().squeeze()
next_obs, rew, done, _ = self.env.step(action)
self.data_buffer.add(obs, action, next_obs, rew, episode_count+1, done)
if args.save_video:
@@ -274,184 +276,222 @@ class DPI:
done=False
else:
obs = next_obs
+epi_reward += rew
+all_rews.append(epi_reward)
if args.save_video:
self.env.video.save('noisy/%d.mp4' % episode_count)
print("Collected {} random episodes".format(episode_count+1))
+return all_rews
-def train(self):
+def train(self, step, total_steps):
-# collect experience
+counter = 0
self.collect_sequences(self.args.batch_size)
# Group observations and next_observations by steps from past to present
last_observations = torch.tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()[:self.args.episode_length-1]
current_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()[:self.args.episode_length-1]
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()[1:]
actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[:self.args.episode_length-1]
next_actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[1:]
rewards = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"rewards",obs=False)).float()[1:]
# Preprocessing
last_observations = preprocess_obs(last_observations)
current_observations = preprocess_obs(current_observations)
next_observations = preprocess_obs(next_observations)
# Initialize transition model states
self.transition_model.init_states(self.args.batch_size, device="cpu") # (N,128)
self.history = self.transition_model.prev_history # (N,128)
# Train encoder
step = 0
total_steps = 10000
metrics = {}
while step < total_steps:
for i in range(self.args.episode_length-1):
-if i > 0:
+# collect experience
-# Encode observations and next_observations
+if step !=0:
-self.last_states_dict = self.get_features(last_observations[i])
+encoder = self.obs_encoder
-self.current_states_dict = self.get_features(current_observations[i])
+actor = self.actor_model
-self.next_states_dict = self.get_features(next_observations[i], momentum=True)
+#all_rews = self.collect_sequences(self.args.batch_size, random=True)
-self.action = actions[i] # (N,6)
+all_rews = self.collect_sequences(self.args.batch_size, random=False, actor_model=actor, encoder_model=encoder)
-self.next_action = next_actions[i] # (N,6)
+else:
-history = self.transition_model.prev_history
+all_rews = self.collect_sequences(self.args.batch_size, random=True)
# Encode negative observations
idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch
random_time_index = torch.randint(0, self.args.episode_length-2, (1,)).item() # random time index
negative_current_observations = current_observations[random_time_index][idx]
self.negative_current_states_dict = self.obs_encoder(negative_current_observations)
-# Predict current state from past state with transition model
+# Group by steps and sample random batch
-last_states_sample = self.last_states_dict["sample"]
+random_indices = self.data_buffer.sample_random_idx(self.args.batch_size * ((step//self.args.collection_interval)+1)) # random indices for batch
-predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
+#random_indices = np.arange(self.args.batch_size * ((step//self.args.collection_interval)),self.args.batch_size * ((step//self.args.collection_interval)+1))
-self.history = predicted_current_state_dict["history"]
+last_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"observations", "cpu", random_indices=random_indices)
current_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"next_observations", device="cpu", random_indices=random_indices)
next_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"next_observations", device="cpu", offset=1, random_indices=random_indices)
actions = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"actions", device=device, is_obs=False, random_indices=random_indices)
next_actions = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"actions", device=device, is_obs=False, offset=1, random_indices=random_indices)
rewards = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"rewards", device=device, is_obs=False, offset=1, random_indices=random_indices)
-# Calculate upper bound loss
+# Preprocessing
-likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict,
+last_observations = preprocess_obs(last_observations).to(device)
-self.current_states_dict,
+current_observations = preprocess_obs(current_observations).to(device)
-self.negative_current_states_dict,
+next_observations = preprocess_obs(next_observations).to(device)
predicted_current_state_dict
)
#likeli_loss = torch.tensor(likeli_loss.numpy(),dtype=torch.float32, requires_grad=True)
#ikeli_loss = likeli_loss.mean()
# Calculate encoder loss
encoder_loss = self._past_encoder_loss(self.current_states_dict,
predicted_current_state_dict)
-#total_ub_loss += ub_loss
+# Initialize transition model states
-#total_encoder_loss += encoder_loss
+self.transition_model.init_states(self.args.batch_size, device) # (N,128)
self.history = self.transition_model.prev_history # (N,128)
# contrastive projection
vec_anchor = predicted_current_state_dict["sample"]
vec_positive = self.next_states_dict["sample"].detach()
z_anchor = self.prjoection_head(vec_anchor, self.action)
z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach()
# contrastive loss # Train encoder
logits = self.contrastive_head(z_anchor, z_positive) if step == 0:
labels = labels = torch.arange(logits.shape[0]).long() step += 1
lb_loss = F.cross_entropy(logits, labels) for _ in range(self.args.collection_interval // self.args.episode_length+1):
counter += 1
-# behaviour learning
+for i in range(self.args.episode_length-1):
-with FreezeParameters(self.world_model_modules):
+if i > 0:
-imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
+# Encode observations and next_observations
-imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(),
+self.last_states_dict = self.get_features(last_observations[i])
-self.next_action, self.history.detach(),
+self.current_states_dict = self.get_features(current_observations[i])
-imagine_horizon)
+self.next_states_dict = self.get_features(next_observations[i], momentum=True)
+self.action = actions[i] # (N,6)
-# decoder loss
+self.next_action = next_actions[i] # (N,6)
-horizon = np.minimum(50-i, imagine_horizon)
+history = self.transition_model.prev_history
obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon])
decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:]))
# reward loss
reward_dist = self.reward_model(self.current_states_dict["sample"])
reward_loss = -torch.mean(reward_dist.log_prob(rewards[:-1]))
# update models
world_model_loss = encoder_loss + ub_loss + lb_loss + decoder_loss * 1e-2
self.world_model_opt.zero_grad()
world_model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step()
# actor loss
with FreezeParameters(self.world_model_modules + self.value_modules):
imag_rew_dist = self.reward_model(imagined_rollout["sample"])
target_imag_val_dist = self.target_value_model(imagined_rollout["sample"])
imag_rews = imag_rew_dist.mean
target_imag_vals = target_imag_val_dist.mean
discounts = self.args.discount * torch.ones_like(imag_rews).detach()
self.target_returns = self._compute_lambda_return(imag_rews[:-1],
target_imag_vals[:-1],
discounts[:-1] ,
self.args.td_lambda,
target_imag_vals[-1])
discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
self.discounts = torch.cumprod(discounts, 0).detach()
actor_loss = -torch.mean(self.discounts * self.target_returns)
# update actor
self.actor_opt.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
self.actor_opt.step()
# value loss
with torch.no_grad():
value_feat = imagined_rollout["sample"][:-1].detach()
value_targ = self.target_returns.detach()
value_dist = self.value_model(value_feat)
value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
# update value
self.value_opt.zero_grad()
value_loss.backward()
nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
self.value_opt.step()
# update target value
if step % self.args.value_target_update_freq == 0:
self.target_value_model = copy.deepcopy(self.value_model)
# update momentum encoder
soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
# update momentum projection head
soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
step += 1
if step % self.args.logging_freq:
writer.add_scalar('Main Loss/World Loss', world_model_loss, step)
writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss, step)
writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, step)
writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss, step)
writer.add_scalar('Actor Critic Loss/Value Loss', value_loss, step)
writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss, step)
writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss, step)
writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss, step)
""" # Encode negative observations
if step % self.args.logging_freq: idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch
metrics['Upper Bound Loss'] = ub_loss.item() random_time_index = torch.randint(0, self.args.episode_length-2, (1,)).item() # random time index
metrics['Encoder Loss'] = encoder_loss.item() negative_current_observations = current_observations[random_time_index][idx]
metrics['Decoder Loss'] = decoder_loss.item() self.negative_current_states_dict = self.obs_encoder(negative_current_observations)
metrics["Lower Bound Loss"] = lb_loss.item()
metrics["World Model Loss"] = world_model_loss.item()
wandb.log(metrics)
"""
-if step>total_steps:
+# Predict current state from past state with transition model
-print("Training finished")
+last_states_sample = self.last_states_dict["sample"]
-break
+predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
self.history = predicted_current_state_dict["history"]
# Calculate upper bound loss
likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict,
self.current_states_dict,
self.negative_current_states_dict,
predicted_current_state_dict
)
# Calculate encoder loss
encoder_loss = self._past_encoder_loss(self.current_states_dict,
predicted_current_state_dict)
# contrastive projection
vec_anchor = predicted_current_state_dict["sample"]
vec_positive = self.next_states_dict["sample"].detach()
z_anchor = self.prjoection_head(vec_anchor, self.action)
z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach()
# contrastive loss
logits = self.contrastive_head(z_anchor, z_positive)
labels = torch.arange(logits.shape[0]).long().to(device)
lb_loss = F.cross_entropy(logits, labels)
# behaviour learning
with FreezeParameters(self.world_model_modules):
imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(),
self.next_action, self.history.detach(),
imagine_horizon)
# decoder loss
horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon])
decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:]))
# reward loss
reward_dist = self.reward_model(self.current_states_dict["sample"])
reward_loss = -torch.mean(reward_dist.log_prob(rewards[:-1]))
# update models
world_model_loss = encoder_loss + 100 * ub_loss + lb_loss + reward_loss + decoder_loss * 1e-2
self.world_model_opt.zero_grad()
world_model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step()
# update momentum encoder
soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
# update momentum projection head
soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
# actor loss
with FreezeParameters(self.world_model_modules + self.value_modules):
imag_rew_dist = self.reward_model(imagined_rollout["sample"])
target_imag_val_dist = self.target_value_model(imagined_rollout["sample"])
imag_rews = imag_rew_dist.mean
target_imag_vals = target_imag_val_dist.mean
discounts = self.args.discount * torch.ones_like(imag_rews).detach()
self.target_returns = self._compute_lambda_return(imag_rews[:-1],
target_imag_vals[:-1],
discounts[:-1] ,
self.args.td_lambda,
target_imag_vals[-1])
discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
self.discounts = torch.cumprod(discounts, 0).detach()
actor_loss = -torch.mean(self.discounts * self.target_returns)
# update actor
self.actor_opt.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
self.actor_opt.step()
# value loss
with torch.no_grad():
value_feat = imagined_rollout["sample"][:-1].detach()
value_targ = self.target_returns.detach()
value_dist = self.value_model(value_feat)
value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
# update value
self.value_opt.zero_grad()
value_loss.backward()
nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
self.value_opt.step()
# update target value
if step % self.args.value_target_update_freq == 0:
self.target_value_model = copy.deepcopy(self.value_model)
# counter for reward
count = np.arange((counter-1) * (self.args.batch_size), (counter) * (self.args.batch_size))
if step % self.args.logging_freq:
writer.add_scalar('World Loss/World Loss', world_model_loss.detach().item(), step)
writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss.detach().item(), step)
writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, step)
writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), step)
writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step)
writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), step)
step += 1
if step>total_steps:
print("Training finished")
break
# save model
if step % self.args.saving_interval == 0:
path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
self.save_models(path)
#torch.cuda.empty_cache() # memory leak issues
for j in range(len(all_rews)):
writer.add_scalar('Rewards/Rewards', all_rews[j], count[j])
def evaluate(self, env, eval_episodes, render=False):
episode_rew = np.zeros((eval_episodes))
video_images = [[] for _ in range(eval_episodes)]
for i in range(eval_episodes):
obs = env.reset()
done = False
prev_state = self.rssm.init_state(1, self.device)
prev_action = torch.zeros(1, self.action_size).to(self.device)
while not done:
with torch.no_grad():
posterior, action = self.act_with_world_model(obs, prev_state, prev_action)
action = action[0].cpu().numpy()
next_obs, rew, done, _ = env.step(action)
prev_state = posterior
prev_action = torch.tensor(action, dtype=torch.float32).to(self.device).unsqueeze(0)
episode_rew[i] += rew
if render:
video_images[i].append(obs['image'].transpose(1,2,0).copy())
obs = next_obs
return episode_rew, np.array(video_images[:self.args.max_videos_to_save])
def _upper_bound_minimization(self, last_states, current_states, negative_current_states, predicted_current_states):
club_sample = CLUBSample(last_states,
@@ -469,8 +509,6 @@ class DPI:
# predicted current state distribution
predicted_curr_states_dist = predicted_curr_states_dict["distribution"]
# KL divergence loss
loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean()
@@ -501,11 +539,27 @@ class DPI:
returns = torch.flip(torch.stack(rets), [0])
return returns
def save_models(self, save_path):
torch.save(
{'rssm' : self.transition_model.state_dict(),
'actor': self.actor_model.state_dict(),
'reward_model': self.reward_model.state_dict(),
'obs_encoder': self.obs_encoder.state_dict(),
'obs_decoder': self.obs_decoder.state_dict(),
'actor_optimizer': self.actor_opt.state_dict(),
'value_optimizer': self.value_opt.state_dict(),
'world_model_optimizer': self.world_model_opt.state_dict(),}, save_path)
if __name__ == '__main__':
args = parse_args()
writer = SummaryWriter()
-dpi = DPI(args, writer)
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-dpi.train()
+step = 0
+total_steps = 10000
+dpi = DPI(args)
+dpi.train(step,total_steps)
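
Note on the memory-leak treatment named in the commit title: the new logging calls pass loss values through .detach().item() instead of handing the live loss tensors to the SummaryWriter (the commented-out torch.cuda.empty_cache() points at the same symptom). A minimal standalone sketch of why that matters; this is illustrative code, not this repo's:

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
history = []

for step in range(100):
    loss = model(torch.randn(32, 8)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # history.append(loss) would keep every iteration's autograd graph reachable,
    # so memory grows with the number of steps; storing the detached Python float
    # keeps only the scalar value.
    history.append(loss.detach().item())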
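
Note on the actor target: _compute_lambda_return is called as _compute_lambda_return(imag_rews[:-1], target_imag_vals[:-1], discounts[:-1], td_lambda, target_imag_vals[-1]) and its result is flipped back into time order with torch.flip(torch.stack(rets), [0]). The helper itself is not part of this diff; a Dreamer-style lambda-return sketch consistent with that call signature (an assumption, not the repo's exact implementation):

import torch

def lambda_return_sketch(rewards, values, discounts, td_lambda, bootstrap):
    # rewards, values, discounts: tensors of shape (horizon, batch, 1);
    # bootstrap: value estimate for the state one step past the horizon.
    next_values = torch.cat([values[1:], bootstrap.unsqueeze(0)], dim=0)
    inputs = rewards + discounts * next_values * (1.0 - td_lambda)
    last = bootstrap
    outputs = []
    for t in reversed(range(rewards.shape[0])):
        # R_t = r_t + gamma_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1})
        last = inputs[t] + discounts[t] * td_lambda * last
        outputs.append(last)
    # stacked in reverse time order, then flipped, mirroring the code above
    return torch.flip(torch.stack(outputs), [0])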
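
Note on the momentum encoder and projection head: after each world-model update the code calls soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau) and the same for the projection head. That utility lives in utils and is not shown in this diff; the usual Polyak/EMA form it implements is roughly the following sketch, under that assumption:

import torch

def soft_update_params_sketch(net, target_net, tau):
    # Polyak averaging: target <- tau * online + (1 - tau) * target,
    # so the momentum copy trails the online network with time constant ~1/tau.
    with torch.no_grad():
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)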