Compare commits

..

No commits in common. "ac714e3495daadeb60ebc61a00bafcacd6c2a453" and "8fd56ba94ded48373ce8af0c6b4b243c16642111" have entirely different histories.

3 changed files with 78 additions and 182 deletions

View File

@ -19,7 +19,7 @@ class ObservationEncoder(nn.Module):
input_channels = obs_shape[0] if i == 0 else output_channels input_channels = obs_shape[0] if i == 0 else output_channels
output_channels = num_filters * (2 ** i) output_channels = num_filters * (2 ** i)
layers.append(nn.Conv2d(in_channels=input_channels, out_channels= output_channels, kernel_size=4, stride=2)) layers.append(nn.Conv2d(in_channels=input_channels, out_channels= output_channels, kernel_size=4, stride=2))
layers.append(nn.LeakyReLU()) layers.append(nn.ReLU())
self.convs = nn.Sequential(*layers) self.convs = nn.Sequential(*layers)
@ -196,8 +196,7 @@ class TransitionModel(nn.Module):
def imagine_step(self, prev_state, prev_action, prev_history): def imagine_step(self, prev_state, prev_action, prev_history):
state_action = self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))) state_action = self.act_fn(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1)))
prev_hist = prev_history.detach() history = self.history_cell(torch.cat([state_action, prev_history], dim=-1), prev_history)
history = self.history_cell(torch.cat([state_action, prev_hist], dim=-1), prev_hist)
state_prior = self.fc_state_prior(torch.cat([history, prev_state, prev_action], dim=-1)) state_prior = self.fc_state_prior(torch.cat([history, prev_state, prev_action], dim=-1))
state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1) state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)

View File

@ -10,17 +10,14 @@ import dmc2gym
import tqdm import tqdm
import wandb import wandb
import utils import utils
from utils import ReplayBuffer, FreezeParameters, make_env, soft_update_params, save_image from utils import ReplayBuffer, make_env, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample, Actor, ValueModel, RewardModel
from logger import Logger from logger import Logger
from video import VideoRecorder from video import VideoRecorder
from dmc2gym.wrappers import set_global_var from dmc2gym.wrappers import set_global_var
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T import torchvision.transforms as T
#from agent.baseline_agent import BaselineAgent #from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent #from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent #from agent.deepmdp_agent import DeepMDPAgent
@ -56,27 +53,23 @@ def parse_args():
parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models') parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models')
parser.add_argument('--load_encoder', default=None, type=str) parser.add_argument('--load_encoder', default=None, type=str)
parser.add_argument('--imagine_horizon', default=15, type=str) parser.add_argument('--imagine_horizon', default=15, type=str)
parser.add_argument('--grad_clip_norm', type=float, default=100.0, help='Gradient clipping norm')
# eval # eval
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000 parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
parser.add_argument('--num_eval_episodes', default=20, type=int) parser.add_argument('--num_eval_episodes', default=20, type=int)
# value # critic
parser.add_argument('--value_lr', default=1e-4, type=float) parser.add_argument('--critic_lr', default=1e-3, type=float)
parser.add_argument('--value_beta', default=0.9, type=float) parser.add_argument('--critic_beta', default=0.9, type=float)
parser.add_argument('--value_tau', default=0.005, type=float) parser.add_argument('--critic_tau', default=0.005, type=float)
parser.add_argument('--value_target_update_freq', default=2, type=int) parser.add_argument('--critic_target_update_freq', default=2, type=int)
# reward
parser.add_argument('--reward_lr', default=1e-4, type=float)
# actor # actor
parser.add_argument('--actor_lr', default=1e-4, type=float) parser.add_argument('--actor_lr', default=1e-3, type=float)
parser.add_argument('--actor_beta', default=0.9, type=float) parser.add_argument('--actor_beta', default=0.9, type=float)
parser.add_argument('--actor_log_std_min', default=-10, type=float) parser.add_argument('--actor_log_std_min', default=-10, type=float)
parser.add_argument('--actor_log_std_max', default=2, type=float) parser.add_argument('--actor_log_std_max', default=2, type=float)
parser.add_argument('--actor_update_freq', default=2, type=int) parser.add_argument('--actor_update_freq', default=2, type=int)
# world/encoder/decoder # encoder/decoder
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity']) parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
parser.add_argument('--encoder_feature_dim', default=50, type=int) parser.add_argument('--encoder_feature_dim', default=50, type=int)
parser.add_argument('--world_model_lr', default=1e-3, type=float)
parser.add_argument('--encoder_lr', default=1e-3, type=float) parser.add_argument('--encoder_lr', default=1e-3, type=float)
parser.add_argument('--encoder_tau', default=0.005, type=float) parser.add_argument('--encoder_tau', default=0.005, type=float)
parser.add_argument('--encoder_stride', default=1, type=int) parser.add_argument('--encoder_stride', default=1, type=int)
@ -86,7 +79,6 @@ def parse_args():
parser.add_argument('--decoder_weight_lambda', default=0.0, type=float) parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_layers', default=4, type=int)
parser.add_argument('--num_filters', default=32, type=int) parser.add_argument('--num_filters', default=32, type=int)
parser.add_argument('--aug', action='store_true')
# sac # sac
parser.add_argument('--discount', default=0.99, type=float) parser.add_argument('--discount', default=0.99, type=float)
parser.add_argument('--init_temperature', default=0.01, type=float) parser.add_argument('--init_temperature', default=0.01, type=float)
@ -162,7 +154,6 @@ class DPI:
self.build_models(use_saved=False, saved_model_dir=self.model_dir) self.build_models(use_saved=False, saved_model_dir=self.model_dir)
def build_models(self, use_saved, saved_model_dir=None): def build_models(self, use_saved, saved_model_dir=None):
# World Models
self.obs_encoder = ObservationEncoder( self.obs_encoder = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84)
state_size=self.args.state_size # 128 state_size=self.args.state_size # 128
@ -185,14 +176,12 @@ class DPI:
history_size=self.args.history_size, # 128 history_size=self.args.history_size, # 128
) )
# Actor Model self.action_model = Actor(
self.actor_model = Actor(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256, hidden_size=self.args.hidden_size, # 256,
action_size=self.env.action_space.shape[0], # 6 action_size=self.env.action_space.shape[0], # 6
) )
# Value Models
self.value_model = ValueModel( self.value_model = ValueModel(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
@ -208,38 +197,12 @@ class DPI:
hidden_size=self.args.hidden_size, # 256 hidden_size=self.args.hidden_size, # 256
) )
# Contrastive Models
self.prjoection_head = ProjectionHead(
state_size=self.args.state_size, # 128
action_size=self.env.action_space.shape[0], # 6
hidden_size=self.args.hidden_size, # 256
)
self.prjoection_head_momentum = ProjectionHead(
state_size=self.args.state_size, # 128
action_size=self.env.action_space.shape[0], # 6
hidden_size=self.args.hidden_size, # 256
)
self.contrastive_head = ContrastiveHead(
hidden_size=self.args.hidden_size, # 256
)
# model parameters # model parameters
self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_decoder.parameters()) + \ self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \
list(self.value_model.parameters()) + list(self.transition_model.parameters()) + \ list(self.obs_decoder.parameters()) + list(self.transition_model.parameters())
list(self.prjoection_head.parameters())
# optimizers # optimizer
self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr) self.optimizer = torch.optim.Adam(self.model_parameters, lr=self.args.encoder_lr)
self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr)
self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr)
# Create Modules
self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.value_model, self.transition_model, self.prjoection_head]
self.value_modules = [self.value_model]
self.actor_modules = [self.actor_model]
if use_saved: if use_saved:
self._use_saved_models(saved_model_dir) self._use_saved_models(saved_model_dir)
@ -251,8 +214,6 @@ class DPI:
def collect_sequences(self, episodes): def collect_sequences(self, episodes):
obs = self.env.reset() obs = self.env.reset()
self.ob_mean = np.mean(obs, 0).astype(np.float32)
self.ob_std = np.std(obs, 0).mean().astype(np.float32)
#obs_clean = self.env_clean.reset() #obs_clean = self.env_clean.reset()
done = False done = False
@ -304,15 +265,14 @@ class DPI:
self.history = self.transition_model.prev_history # (N,128) self.history = self.transition_model.prev_history # (N,128)
# Train encoder # Train encoder
step = 0 total_ub_loss = 0
total_steps = 10000 total_encoder_loss = 0
while step < total_steps:
for i in range(self.args.episode_length-1): for i in range(self.args.episode_length-1):
if i > 0: if i > 0:
# Encode observations and next_observations # Encode observations and next_observations
self.last_states_dict = self.get_features(last_observations[i]) self.last_states_dict = self.obs_encoder(last_observations[i])
self.current_states_dict = self.get_features(current_observations[i]) self.current_states_dict = self.obs_encoder(current_observations[i])
self.next_states_dict = self.get_features(next_observations[i], momentum=True) self.next_states_dict = self.obs_encoder_momentum(next_observations[i])
self.action = actions[i] # (N,6) self.action = actions[i] # (N,6)
history = self.transition_model.prev_history history = self.transition_model.prev_history
@ -327,8 +287,6 @@ class DPI:
predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history) predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
self.history = predicted_current_state_dict["history"] self.history = predicted_current_state_dict["history"]
# Calculate upper bound loss # Calculate upper bound loss
ub_loss = self._upper_bound_minimization(self.last_states_dict, ub_loss = self._upper_bound_minimization(self.last_states_dict,
self.current_states_dict, self.current_states_dict,
@ -340,45 +298,12 @@ class DPI:
encoder_loss = self._past_encoder_loss(self.current_states_dict, encoder_loss = self._past_encoder_loss(self.current_states_dict,
predicted_current_state_dict) predicted_current_state_dict)
#total_ub_loss += ub_loss total_ub_loss += ub_loss
#total_encoder_loss += encoder_loss total_encoder_loss += encoder_loss
# contrastive projection
vec_anchor = predicted_current_state_dict["sample"]
vec_positive = self.next_states_dict["sample"].detach()
z_anchor = self.prjoection_head(vec_anchor, self.action)
z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach()
# contrastive loss
logits = self.contrastive_head(z_anchor, z_positive)
labels = labels = torch.arange(logits.shape[0]).long()
lb_loss = F.cross_entropy(logits, labels)
# update models
world_model_loss = encoder_loss + 1e-1 * ub_loss + lb_loss #1e-1 * ub_loss + 1e-5 * encoder_loss + 1e-1 * lb_loss
print("ub_loss: {:.4f}, encoder_loss: {:.4f}, lb_loss: {:.4f}".format(ub_loss, encoder_loss, lb_loss))
print("world_model_loss: {:.4f}".format(world_model_loss))
self.world_model_opt.zero_grad()
world_model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step()
# behaviour learning
with FreezeParameters(self.world_model_modules):
imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i) imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(), imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"], self.action, self.history, imagine_horizon)
self.action, self.history.detach(),
imagine_horizon)
print(imagined_rollout["sample"].shape, imagined_rollout["distribution"][0].sample().shape)
#exit()
step += 1
if step>total_steps:
print("Training finished")
break
#exit() #exit()
#print(total_ub_loss, total_encoder_loss) #print(total_ub_loss, total_encoder_loss)
@ -390,7 +315,7 @@ class DPI:
current_states, current_states,
negative_current_states, negative_current_states,
predicted_current_states) predicted_current_states)
club_loss = club_sample.loglikeli() club_loss = club_sample()
return club_loss return club_loss
def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict): def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict):
@ -400,27 +325,42 @@ class DPI:
# predicted current state distribution # predicted current state distribution
predicted_curr_states_dist = predicted_curr_states_dict["distribution"] predicted_curr_states_dist = predicted_curr_states_dict["distribution"]
# KL divergence loss # KL divergence loss
loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean() loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean()
return loss return loss
def get_features(self, x, momentum=False): """
import torchvision.transforms.functional as fn def _past_encoder_loss(self, states, next_states, states_dist, next_states_dist, actions, history, step):
x = x/255.0 - 0.5 # Preprocessing # Imagine next state
if step == 0:
actions = torch.zeros(self.args.batch_size, self.env.action_space.shape[0]).float() # Zero action for first step
imagined_next_states = self.transition_model.imagine_step(states, actions, history)
self.history = imagined_next_states["history"]
else:
imagined_next_states = self.transition_model.imagine_step(states, actions, self.history) # (N,128)
if self.args.aug: # State Distribution
imagined_next_states_dist = imagined_next_states["distribution"]
# KL divergence loss
loss = torch.distributions.kl.kl_divergence(imagined_next_states_dist, next_states_dist["distribution"]).mean()
return loss
"""
def get_features(self, x, momentum=False):
if self.aug:
x = T.RandomCrop((80, 80))(x) # (None,80,80,4) x = T.RandomCrop((80, 80))(x) # (None,80,80,4)
x = T.functional.pad(x, (4, 4, 4, 4), "symmetric") # (None,88,88,4) x = T.functional.pad(x, (4, 4, 4, 4), "symmetric") # (None,88,88,4)
x = T.RandomCrop((84, 84))(x) # (None,84,84,4) x = T.RandomCrop((84, 84))(x) # (None,84,84,4)
with torch.no_grad(): with torch.no_grad():
x = (x.float() - self.ob_mean) / self.ob_std
if momentum: if momentum:
x = self.obs_encoder_momentum(x) x = self.obs_encoder(x).detach()
else: else:
x = self.obs_encoder(x) x = self.obs_encoder_momentum(x)
return x return x
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -17,7 +17,6 @@ import dmc2gym
import cv2 import cv2
from PIL import Image from PIL import Image
from typing import Iterable
class eval_mode(object): class eval_mode(object):
@ -198,12 +197,6 @@ def make_env(args):
) )
return env return env
def soft_update_params(net, target_net, tau):
    """
    Polyak-average the parameters of ``net`` into ``target_net`` in place.

    Each target parameter becomes ``tau * source + (1 - tau) * target``.
    Parameters are paired positionally in the order yielded by
    ``.parameters()``; ``net`` itself is never modified.

    :param net: module supplying the fresh parameter values
    :param target_net: module whose parameters are overwritten in place
    :param tau: interpolation weight; ``1`` copies ``net``, ``0`` leaves the target untouched
    """
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        # detach() shares storage with the parameter but keeps the write out
        # of the autograd graph. Unlike the legacy `.data` attribute, the
        # in-place copy is still recorded in the tensor's version counter.
        source = param.detach()
        target = target_param.detach()
        target.copy_(tau * source + (1 - tau) * target)
def save_image(array, filename): def save_image(array, filename):
array = array.transpose(1, 2, 0) array = array.transpose(1, 2, 0)
array = (array * 255).astype(np.uint8) array = (array * 255).astype(np.uint8)
@ -264,39 +257,3 @@ class CorruptVideos:
if delete: if delete:
self._delete_corrupt_video(filepath) self._delete_corrupt_video(filepath)
print(f"Deleted {filepath}") print(f"Deleted {filepath}")
def get_parameters(modules: Iterable[nn.Module]) -> list:
    """
    Given an iterable of torch modules, returns a flat list of their parameters.

    :param modules: iterable of modules; each one's ``.parameters()`` is read in turn
    :returns: a list of parameters, ordered module by module
    """
    # One flat comprehension instead of building a temporary list per module
    # and concatenating it onto an accumulator.
    return [param for module in modules for param in module.parameters()]
class FreezeParameters:
    """
    Context manager to locally freeze gradients.

    In some cases this can speed up computation because gradients aren't
    calculated for the listed modules.

    example:
    ```
    with FreezeParameters([module]):
        output_tensor = module(input_tensor)
    ```
    """

    def __init__(self, modules: Iterable[nn.Module]):
        """
        :param modules: iterable of modules. used to call .parameters() to freeze gradients.
        """
        # Materialize once: a generator would otherwise be exhausted here and
        # leave nothing to freeze when the context is entered.
        self.modules = list(modules)
        self.param_states = [p.requires_grad for p in self._parameters()]
        self._frozen = []

    def _parameters(self):
        # Flat list of every parameter owned by the managed modules.
        return [p for module in self.modules for p in module.parameters()]

    def __enter__(self):
        # Snapshot at entry, not construction, so that flags changed after the
        # object was created are restored to their true pre-entry values.
        self._frozen = self._parameters()
        self.param_states = [p.requires_grad for p in self._frozen]
        for param in self._frozen:
            param.requires_grad = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore exactly the tensors frozen in __enter__, paired with their
        # saved flags, instead of indexing into a freshly re-queried list.
        for param, state in zip(self._frozen, self.param_states):
            param.requires_grad = state