Compare commits

..

3 Commits

Author SHA1 Message Date
ac714e3495 Correct history with detach 2023-04-10 20:18:39 +02:00
de17cab9f5 Add MOCO to introduce lower bound loss 2023-04-10 20:18:17 +02:00
05dd20cdfa Add a class to freeze parameters 2023-04-10 20:17:44 +02:00
3 changed files with 182 additions and 78 deletions

View File

@ -19,7 +19,7 @@ class ObservationEncoder(nn.Module):
input_channels = obs_shape[0] if i == 0 else output_channels input_channels = obs_shape[0] if i == 0 else output_channels
output_channels = num_filters * (2 ** i) output_channels = num_filters * (2 ** i)
layers.append(nn.Conv2d(in_channels=input_channels, out_channels= output_channels, kernel_size=4, stride=2)) layers.append(nn.Conv2d(in_channels=input_channels, out_channels= output_channels, kernel_size=4, stride=2))
layers.append(nn.ReLU()) layers.append(nn.LeakyReLU())
self.convs = nn.Sequential(*layers) self.convs = nn.Sequential(*layers)
@ -196,7 +196,8 @@ class TransitionModel(nn.Module):
def imagine_step(self, prev_state, prev_action, prev_history): def imagine_step(self, prev_state, prev_action, prev_history):
state_action = self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))) state_action = self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1)))
history = self.history_cell(torch.cat([state_action, prev_history], dim=-1), prev_history) prev_hist = prev_history.detach()
history = self.history_cell(torch.cat([state_action, prev_hist], dim=-1), prev_hist)
state_prior = self.fc_state_prior(torch.cat([history, prev_state, prev_action], dim=-1)) state_prior = self.fc_state_prior(torch.cat([history, prev_state, prev_action], dim=-1))
state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1) state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)

View File

@ -10,14 +10,17 @@ import dmc2gym
import tqdm import tqdm
import wandb import wandb
import utils import utils
from utils import ReplayBuffer, make_env, save_image from utils import ReplayBuffer, FreezeParameters, make_env, soft_update_params, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample, Actor, ValueModel, RewardModel from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
from logger import Logger from logger import Logger
from video import VideoRecorder from video import VideoRecorder
from dmc2gym.wrappers import set_global_var from dmc2gym.wrappers import set_global_var
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T import torchvision.transforms as T
#from agent.baseline_agent import BaselineAgent #from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent #from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent #from agent.deepmdp_agent import DeepMDPAgent
@ -53,23 +56,27 @@ def parse_args():
parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models') parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models')
parser.add_argument('--load_encoder', default=None, type=str) parser.add_argument('--load_encoder', default=None, type=str)
parser.add_argument('--imagine_horizon', default=15, type=str) parser.add_argument('--imagine_horizon', default=15, type=str)
parser.add_argument('--grad_clip_norm', type=float, default=100.0, help='Gradient clipping norm')
# eval # eval
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000 parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
parser.add_argument('--num_eval_episodes', default=20, type=int) parser.add_argument('--num_eval_episodes', default=20, type=int)
# critic # value
parser.add_argument('--critic_lr', default=1e-3, type=float) parser.add_argument('--value_lr', default=1e-4, type=float)
parser.add_argument('--critic_beta', default=0.9, type=float) parser.add_argument('--value_beta', default=0.9, type=float)
parser.add_argument('--critic_tau', default=0.005, type=float) parser.add_argument('--value_tau', default=0.005, type=float)
parser.add_argument('--critic_target_update_freq', default=2, type=int) parser.add_argument('--value_target_update_freq', default=2, type=int)
# reward
parser.add_argument('--reward_lr', default=1e-4, type=float)
# actor # actor
parser.add_argument('--actor_lr', default=1e-3, type=float) parser.add_argument('--actor_lr', default=1e-4, type=float)
parser.add_argument('--actor_beta', default=0.9, type=float) parser.add_argument('--actor_beta', default=0.9, type=float)
parser.add_argument('--actor_log_std_min', default=-10, type=float) parser.add_argument('--actor_log_std_min', default=-10, type=float)
parser.add_argument('--actor_log_std_max', default=2, type=float) parser.add_argument('--actor_log_std_max', default=2, type=float)
parser.add_argument('--actor_update_freq', default=2, type=int) parser.add_argument('--actor_update_freq', default=2, type=int)
# encoder/decoder # world/encoder/decoder
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity']) parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
parser.add_argument('--encoder_feature_dim', default=50, type=int) parser.add_argument('--encoder_feature_dim', default=50, type=int)
parser.add_argument('--world_model_lr', default=1e-3, type=float)
parser.add_argument('--encoder_lr', default=1e-3, type=float) parser.add_argument('--encoder_lr', default=1e-3, type=float)
parser.add_argument('--encoder_tau', default=0.005, type=float) parser.add_argument('--encoder_tau', default=0.005, type=float)
parser.add_argument('--encoder_stride', default=1, type=int) parser.add_argument('--encoder_stride', default=1, type=int)
@ -79,6 +86,7 @@ def parse_args():
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float) parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_layers', default=4, type=int)
parser.add_argument('--num_filters', default=32, type=int) parser.add_argument('--num_filters', default=32, type=int)
parser.add_argument('--aug', action='store_true')
# sac # sac
parser.add_argument('--discount', default=0.99, type=float) parser.add_argument('--discount', default=0.99, type=float)
parser.add_argument('--init_temperature', default=0.01, type=float) parser.add_argument('--init_temperature', default=0.01, type=float)
@ -154,6 +162,7 @@ class DPI:
self.build_models(use_saved=False, saved_model_dir=self.model_dir) self.build_models(use_saved=False, saved_model_dir=self.model_dir)
def build_models(self, use_saved, saved_model_dir=None): def build_models(self, use_saved, saved_model_dir=None):
# World Models
self.obs_encoder = ObservationEncoder( self.obs_encoder = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
state_size=self.args.state_size # 128 state_size=self.args.state_size # 128
@ -176,12 +185,14 @@ class DPI:
history_size=self.args.history_size, # 128 history_size=self.args.history_size, # 128
) )
self.action_model = Actor( # Actor Model
self.actor_model = Actor(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256, hidden_size=self.args.hidden_size, # 256,
action_size=self.env.action_space.shape[0], # 6 action_size=self.env.action_space.shape[0], # 6
) )
# Value Models
self.value_model = ValueModel( self.value_model = ValueModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
@ -196,13 +207,39 @@ class DPI:
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
) )
# Contrastive Models
self.prjoection_head = ProjectionHead(
state_size=self.args.state_size, # 128
action_size=self.env.action_space.shape[0], # 6
hidden_size=self.args.hidden_size, # 256
)
self.prjoection_head_momentum = ProjectionHead(
state_size=self.args.state_size, # 128
action_size=self.env.action_space.shape[0], # 6
hidden_size=self.args.hidden_size, # 256
)
self.contrastive_head = ContrastiveHead(
hidden_size=self.args.hidden_size, # 256
)
# model parameters # model parameters
self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \ self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_decoder.parameters()) + \
list(self.obs_decoder.parameters()) + list(self.transition_model.parameters()) list(self.value_model.parameters()) + list(self.transition_model.parameters()) + \
list(self.prjoection_head.parameters())
# optimizer # optimizers
self.optimizer = torch.optim.Adam(self.model_parameters, lr=self.args.encoder_lr) self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr)
self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr)
self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr)
# Create Modules
self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.value_model, self.transition_model, self.prjoection_head]
self.value_modules = [self.value_model]
self.actor_modules = [self.actor_model]
if use_saved: if use_saved:
self._use_saved_models(saved_model_dir) self._use_saved_models(saved_model_dir)
@ -214,6 +251,8 @@ class DPI:
def collect_sequences(self, episodes): def collect_sequences(self, episodes):
obs = self.env.reset() obs = self.env.reset()
self.ob_mean = np.mean(obs, 0).astype(np.float32)
self.ob_std = np.std(obs, 0).mean().astype(np.float32)
#obs_clean = self.env_clean.reset() #obs_clean = self.env_clean.reset()
done = False done = False
@ -265,48 +304,84 @@ class DPI:
self.history = self.transition_model.prev_history # (N,128) self.history = self.transition_model.prev_history # (N,128)
# Train encoder # Train encoder
total_ub_loss = 0 step = 0
total_encoder_loss = 0 total_steps = 10000
for i in range(self.args.episode_length-1): while step < total_steps:
if i > 0: for i in range(self.args.episode_length-1):
# Encode observations and next_observations if i > 0:
self.last_states_dict = self.obs_encoder(last_observations[i]) # Encode observations and next_observations
self.current_states_dict = self.obs_encoder(current_observations[i]) self.last_states_dict = self.get_features(last_observations[i])
self.next_states_dict = self.obs_encoder_momentum(next_observations[i]) self.current_states_dict = self.get_features(current_observations[i])
self.action = actions[i] # (N,6) self.next_states_dict = self.get_features(next_observations[i], momentum=True)
history = self.transition_model.prev_history self.action = actions[i] # (N,6)
history = self.transition_model.prev_history
# Encode negative observations
idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch # Encode negative observations
random_time_index = torch.randint(0, self.args.episode_length-2, (1,)).item() # random time index idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch
negative_current_observations = current_observations[random_time_index][idx] random_time_index = torch.randint(0, self.args.episode_length-2, (1,)).item() # random time index
self.negative_current_states_dict = self.obs_encoder(negative_current_observations) negative_current_observations = current_observations[random_time_index][idx]
self.negative_current_states_dict = self.obs_encoder(negative_current_observations)
# Predict current state from past state with transition model # Predict current state from past state with transition model
last_states_sample = self.last_states_dict["sample"] last_states_sample = self.last_states_dict["sample"]
predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history) predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
self.history = predicted_current_state_dict["history"] self.history = predicted_current_state_dict["history"]
# Calculate upper bound loss
ub_loss = self._upper_bound_minimization(self.last_states_dict,
self.current_states_dict,
self.negative_current_states_dict,
predicted_current_state_dict
)
# Calculate encoder loss
encoder_loss = self._past_encoder_loss(self.current_states_dict,
predicted_current_state_dict)
total_ub_loss += ub_loss
total_encoder_loss += encoder_loss
imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"], self.action, self.history, imagine_horizon) # Calculate upper bound loss
ub_loss = self._upper_bound_minimization(self.last_states_dict,
self.current_states_dict,
self.negative_current_states_dict,
predicted_current_state_dict
)
# Calculate encoder loss
encoder_loss = self._past_encoder_loss(self.current_states_dict,
predicted_current_state_dict)
#exit() #total_ub_loss += ub_loss
#total_encoder_loss += encoder_loss
# contrastive projection
vec_anchor = predicted_current_state_dict["sample"]
vec_positive = self.next_states_dict["sample"].detach()
z_anchor = self.prjoection_head(vec_anchor, self.action)
z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach()
#print(total_ub_loss, total_encoder_loss) # contrastive loss
logits = self.contrastive_head(z_anchor, z_positive)
labels = labels = torch.arange(logits.shape[0]).long()
lb_loss = F.cross_entropy(logits, labels)
# update models
world_model_loss = encoder_loss + 1e-1 * ub_loss + lb_loss #1e-1 * ub_loss + 1e-5 * encoder_loss + 1e-1 * lb_loss
print("ub_loss: {:.4f}, encoder_loss: {:.4f}, lb_loss: {:.4f}".format(ub_loss, encoder_loss, lb_loss))
print("world_model_loss: {:.4f}".format(world_model_loss))
self.world_model_opt.zero_grad()
world_model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step()
# behaviour learning
with FreezeParameters(self.world_model_modules):
imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(),
self.action, self.history.detach(),
imagine_horizon)
print(imagined_rollout["sample"].shape, imagined_rollout["distribution"][0].sample().shape)
#exit()
step += 1
if step>total_steps:
print("Training finished")
break
#exit()
#print(total_ub_loss, total_encoder_loss)
@ -315,7 +390,7 @@ class DPI:
current_states, current_states,
negative_current_states, negative_current_states,
predicted_current_states) predicted_current_states)
club_loss = club_sample() club_loss = club_sample.loglikeli()
return club_loss return club_loss
def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict): def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict):
@ -325,42 +400,27 @@ class DPI:
# predicted current state distribution # predicted current state distribution
predicted_curr_states_dist = predicted_curr_states_dict["distribution"] predicted_curr_states_dist = predicted_curr_states_dict["distribution"]
# KL divergence loss # KL divergence loss
loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean() loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean()
return loss return loss
"""
def _past_encoder_loss(self, states, next_states, states_dist, next_states_dist, actions, history, step):
# Imagine next state
if step == 0:
actions = torch.zeros(self.args.batch_size, self.env.action_space.shape[0]).float() # Zero action for first step
imagined_next_states = self.transition_model.imagine_step(states, actions, history)
self.history = imagined_next_states["history"]
else:
imagined_next_states = self.transition_model.imagine_step(states, actions, self.history) # (N,128)
# State Distribution
imagined_next_states_dist = imagined_next_states["distribution"]
# KL divergence loss
loss = torch.distributions.kl.kl_divergence(imagined_next_states_dist, next_states_dist["distribution"]).mean()
return loss
"""
def get_features(self, x, momentum=False): def get_features(self, x, momentum=False):
if self.aug: import torchvision.transforms.functional as fn
x = x/255.0 - 0.5 # Preprocessing
if self.args.aug:
x = T.RandomCrop((80, 80))(x) # (None,80,80,4) x = T.RandomCrop((80, 80))(x) # (None,80,80,4)
x = T.functional.pad(x, (4, 4, 4, 4), "symmetric") # (None,88,88,4) x = T.functional.pad(x, (4, 4, 4, 4), "symmetric") # (None,88,88,4)
x = T.RandomCrop((84, 84))(x) # (None,84,84,4) x = T.RandomCrop((84, 84))(x) # (None,84,84,4)
with torch.no_grad(): with torch.no_grad():
x = (x.float() - self.ob_mean) / self.ob_std
if momentum: if momentum:
x = self.obs_encoder(x).detach()
else:
x = self.obs_encoder_momentum(x) x = self.obs_encoder_momentum(x)
else:
x = self.obs_encoder(x)
return x return x
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -17,6 +17,7 @@ import dmc2gym
import cv2 import cv2
from PIL import Image from PIL import Image
from typing import Iterable
class eval_mode(object): class eval_mode(object):
@ -197,6 +198,12 @@ def make_env(args):
) )
return env return env
def soft_update_params(net, target_net, tau):
for param, target_param in zip(net.parameters(), target_net.parameters()):
target_param.data.copy_(
tau * param.data + (1 - tau) * target_param.data
)
def save_image(array, filename): def save_image(array, filename):
array = array.transpose(1, 2, 0) array = array.transpose(1, 2, 0)
array = (array * 255).astype(np.uint8) array = (array * 255).astype(np.uint8)
@ -256,4 +263,40 @@ class CorruptVideos:
print(f"{filepath} is corrupt.") print(f"{filepath} is corrupt.")
if delete: if delete:
self._delete_corrupt_video(filepath) self._delete_corrupt_video(filepath)
print(f"Deleted {filepath}") print(f"Deleted {filepath}")
def get_parameters(modules: Iterable[nn.Module]):
"""
Given a list of torch modules, returns a list of their parameters.
:param modules: iterable of modules
:returns: a list of parameters
"""
model_parameters = []
for module in modules:
model_parameters += list(module.parameters())
return model_parameters
class FreezeParameters:
def __init__(self, modules: Iterable[nn.Module]):
"""
Context manager to locally freeze gradients.
In some cases with can speed up computation because gradients aren't calculated for these listed modules.
example:
```
with FreezeParameters([module]):
output_tensor = module(input_tensor)
```
:param modules: iterable of modules. used to call .parameters() to freeze gradients.
"""
self.modules = modules
self.param_states = [p.requires_grad for p in get_parameters(self.modules)]
def __enter__(self):
for param in get_parameters(self.modules):
param.requires_grad = False
def __exit__(self, exc_type, exc_val, exc_tb):
for i, param in enumerate(get_parameters(self.modules)):
param.requires_grad = self.param_states[i]