Compare commits

...

4 Commits

3 changed files with 74 additions and 59 deletions

View File

@ -88,7 +88,6 @@ class ObservationDecoder(nn.Module):
out = self.dense(features) out = self.dense(features)
out = torch.reshape(out, [-1, self.input_size, 1, 1]) out = torch.reshape(out, [-1, self.input_size, 1, 1])
out = self.convtranspose(out) out = self.convtranspose(out)
mean = torch.reshape(out, (*out_batch_shape, *self.output_shape)) mean = torch.reshape(out, (*out_batch_shape, *self.output_shape))
out_dist = torch.distributions.independent.Independent(torch.distributions.Normal(mean, 1), len(self.output_shape)) out_dist = torch.distributions.independent.Independent(torch.distributions.Normal(mean, 1), len(self.output_shape))
return out_dist return out_dist

View File

@ -7,10 +7,11 @@ import time
import json import json
import dmc2gym import dmc2gym
import copy
import tqdm import tqdm
import wandb import wandb
import utils import utils
from utils import ReplayBuffer, FreezeParameters, make_env, soft_update_params, save_image from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
from logger import Logger from logger import Logger
from video import VideoRecorder from video import VideoRecorder
@ -19,6 +20,8 @@ from dmc2gym.wrappers import set_global_var
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.transforms as T import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter
#from agent.baseline_agent import BaselineAgent #from agent.baseline_agent import BaselineAgent
@ -64,7 +67,7 @@ def parse_args():
parser.add_argument('--value_lr', default=1e-4, type=float) parser.add_argument('--value_lr', default=1e-4, type=float)
parser.add_argument('--value_beta', default=0.9, type=float) parser.add_argument('--value_beta', default=0.9, type=float)
parser.add_argument('--value_tau', default=0.005, type=float) parser.add_argument('--value_tau', default=0.005, type=float)
parser.add_argument('--value_target_update_freq', default=2, type=int) parser.add_argument('--value_target_update_freq', default=100, type=int)
parser.add_argument('--td_lambda', default=0.95, type=int) parser.add_argument('--td_lambda', default=0.95, type=int)
# reward # reward
parser.add_argument('--reward_lr', default=1e-4, type=float) parser.add_argument('--reward_lr', default=1e-4, type=float)
@ -80,7 +83,7 @@ def parse_args():
parser.add_argument('--world_model_lr', default=1e-3, type=float) parser.add_argument('--world_model_lr', default=1e-3, type=float)
parser.add_argument('--past_transition_lr', default=1e-3, type=float) parser.add_argument('--past_transition_lr', default=1e-3, type=float)
parser.add_argument('--encoder_lr', default=1e-3, type=float) parser.add_argument('--encoder_lr', default=1e-3, type=float)
parser.add_argument('--encoder_tau', default=0.005, type=float) parser.add_argument('--encoder_tau', default=0.001, type=float)
parser.add_argument('--encoder_stride', default=1, type=int) parser.add_argument('--encoder_stride', default=1, type=int)
parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction']) parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
parser.add_argument('--decoder_lr', default=1e-3, type=float) parser.add_argument('--decoder_lr', default=1e-3, type=float)
@ -116,7 +119,7 @@ def parse_args():
class DPI: class DPI:
def __init__(self, args): def __init__(self, args, writer):
# wandb config # wandb config
#run = wandb.init(project="dpi") #run = wandb.init(project="dpi")
@ -134,13 +137,10 @@ class DPI:
self.args.version = 2 # env_id changes to v2 self.args.version = 2 # env_id changes to v2
self.args.img_source = None # no image noise self.args.img_source = None # no image noise
self.args.resource_files = None self.args.resource_files = None
self.env_clean = make_env(self.args)
self.env_clean.seed(self.args.seed)
# stack several consecutive frames together # stack several consecutive frames together
if self.args.encoder_type.startswith('pixel'): if self.args.encoder_type.startswith('pixel'):
self.env = utils.FrameStack(self.env, k=self.args.frame_stack) self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
self.env_clean = utils.FrameStack(self.env_clean, k=self.args.frame_stack)
# create replay buffer # create replay buffer
self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity, self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity,
@ -162,18 +162,18 @@ class DPI:
def build_models(self, use_saved, saved_model_dir=None): def build_models(self, use_saved, saved_model_dir=None):
# World Models # World Models
self.obs_encoder = ObservationEncoder( self.obs_encoder = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84)
state_size=self.args.state_size # 128 state_size=self.args.state_size # 128
) )
self.obs_encoder_momentum = ObservationEncoder( self.obs_encoder_momentum = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (12,84,84) obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84)
state_size=self.args.state_size # 128 state_size=self.args.state_size # 128
) )
self.obs_decoder = ObservationDecoder( self.obs_decoder = ObservationDecoder(
state_size=self.args.state_size, # 128 state_size=self.args.state_size, # 128
output_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size) # (12,84,84) output_shape=(self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84)
) )
self.transition_model = TransitionModel( self.transition_model = TransitionModel(
@ -251,41 +251,31 @@ class DPI:
def collect_sequences(self, episodes): def collect_sequences(self, episodes):
obs = self.env.reset() obs = self.env.reset()
self.ob_mean = np.mean(obs, 0).astype(np.float32)
self.ob_std = np.std(obs, 0).mean().astype(np.float32)
#obs_clean = self.env_clean.reset()
done = False done = False
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files) #video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'): for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
if args.save_video: if args.save_video:
self.env.video.init(enabled=True) self.env.video.init(enabled=True)
#self.env_clean.video.init(enabled=True)
for i in range(self.args.episode_length): for i in range(self.args.episode_length):
action = self.env.action_space.sample() action = self.env.action_space.sample()
next_obs, rew, done, _ = self.env.step(action) next_obs, rew, done, _ = self.env.step(action)
#next_obs_clean, _, done, _ = self.env_clean.step(action)
self.data_buffer.add(obs, action, next_obs, rew, episode_count+1, done) self.data_buffer.add(obs, action, next_obs, rew, episode_count+1, done)
#self.data_buffer_clean.add(obs_clean, action, next_obs_clean, episode_count+1, done)
if args.save_video: if args.save_video:
self.env.video.record(self.env) self.env.video.record(self.env)
#self.env_clean.video.record(self.env_clean)
if done or i == self.args.episode_length-1: if done or i == self.args.episode_length-1:
obs = self.env.reset() obs = self.env.reset()
#obs_clean = self.env_clean.reset()
done=False done=False
else: else:
obs = next_obs obs = next_obs
#obs_clean = next_obs_clean
if args.save_video: if args.save_video:
self.env.video.save('noisy/%d.mp4' % episode_count) self.env.video.save('noisy/%d.mp4' % episode_count)
#self.env_clean.video.save('clean/%d.mp4' % episode_count)
print("Collected {} random episodes".format(episode_count+1)) print("Collected {} random episodes".format(episode_count+1))
def train(self): def train(self):
@ -299,7 +289,12 @@ class DPI:
actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[:self.args.episode_length-1] actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[:self.args.episode_length-1]
next_actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[1:] next_actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[1:]
rewards = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"rewards",obs=False)).float()[1:] rewards = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"rewards",obs=False)).float()[1:]
# Preprocessing
last_observations = preprocess_obs(last_observations)
current_observations = preprocess_obs(current_observations)
next_observations = preprocess_obs(next_observations)
# Initialize transition model states # Initialize transition model states
self.transition_model.init_states(self.args.batch_size, device="cpu") # (N,128) self.transition_model.init_states(self.args.batch_size, device="cpu") # (N,128)
self.history = self.transition_model.prev_history # (N,128) self.history = self.transition_model.prev_history # (N,128)
@ -357,42 +352,29 @@ class DPI:
labels = labels = torch.arange(logits.shape[0]).long() labels = labels = torch.arange(logits.shape[0]).long()
lb_loss = F.cross_entropy(logits, labels) lb_loss = F.cross_entropy(logits, labels)
# update models
"""
print(likeli_loss)
for i in range(self.args.num_likelihood_updates):
self.past_transition_opt.zero_grad()
print(likeli_loss)
likeli_loss.backward()
nn.utils.clip_grad_norm_(self.past_transition_parameters, self.args.grad_clip_norm)
self.past_transition_opt.step()
print(encoder_loss, ub_loss, lb_loss, step)
"""
world_model_loss = encoder_loss + ub_loss + lb_loss
self.world_model_opt.zero_grad()
world_model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step()
"""
if step % self.args.logging_freq:
metrics['Upper Bound Loss'] = ub_loss.item()
metrics['Encoder Loss'] = encoder_loss.item()
metrics["Lower Bound Loss"] = lb_loss.item()
metrics["World Model Loss"] = world_model_loss.item()
wandb.log(metrics)
"""
# behaviour learning # behaviour learning
with FreezeParameters(self.world_model_modules): with FreezeParameters(self.world_model_modules):
imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i) imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(), imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(),
self.next_action, self.history.detach(), self.next_action, self.history.detach(),
imagine_horizon) imagine_horizon)
#print(imagined_rollout["sample"].shape, imagined_rollout["distribution"][0].sample().shape)
# decoder loss
horizon = np.minimum(50-i, imagine_horizon)
obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon])
decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:]))
# reward loss
reward_dist = self.reward_model(self.current_states_dict["sample"])
reward_loss = -torch.mean(reward_dist.log_prob(rewards[:-1]))
# update models
world_model_loss = encoder_loss + ub_loss + lb_loss + decoder_loss * 1e-2
self.world_model_opt.zero_grad()
world_model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step()
# actor loss # actor loss
with FreezeParameters(self.world_model_modules + self.value_modules): with FreezeParameters(self.world_model_modules + self.value_modules):
imag_rew_dist = self.reward_model(imagined_rollout["sample"]) imag_rew_dist = self.reward_model(imagined_rollout["sample"])
@ -413,6 +395,7 @@ class DPI:
self.discounts = torch.cumprod(discounts, 0).detach() self.discounts = torch.cumprod(discounts, 0).detach()
actor_loss = -torch.mean(self.discounts * self.target_returns) actor_loss = -torch.mean(self.discounts * self.target_returns)
# update actor
self.actor_opt.zero_grad() self.actor_opt.zero_grad()
actor_loss.backward() actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm) nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
@ -425,18 +408,48 @@ class DPI:
value_dist = self.value_model(value_feat) value_dist = self.value_model(value_feat)
value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1)) value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
# update value
self.value_opt.zero_grad() self.value_opt.zero_grad()
value_loss.backward() value_loss.backward()
nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm) nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
self.value_opt.step() self.value_opt.step()
# update target value
if step % self.args.value_target_update_freq == 0:
self.target_value_model = copy.deepcopy(self.value_model)
# update momentum encoder
soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
# update momentum projection head
soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
step += 1 step += 1
if step % self.args.logging_freq:
writer.add_scalar('Main Loss/World Loss', world_model_loss, step)
writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss, step)
writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, step)
writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss, step)
writer.add_scalar('Actor Critic Loss/Value Loss', value_loss, step)
writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss, step)
writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss, step)
writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss, step)
"""
if step % self.args.logging_freq:
metrics['Upper Bound Loss'] = ub_loss.item()
metrics['Encoder Loss'] = encoder_loss.item()
metrics['Decoder Loss'] = decoder_loss.item()
metrics["Lower Bound Loss"] = lb_loss.item()
metrics["World Model Loss"] = world_model_loss.item()
wandb.log(metrics)
"""
if step>total_steps: if step>total_steps:
print("Training finished") print("Training finished")
break break
#print(total_ub_loss, total_encoder_loss)
@ -463,10 +476,7 @@ class DPI:
return loss return loss
def get_features(self, x, momentum=False): def get_features(self, x, momentum=False):
import torchvision.transforms.functional as fn
x = x/255.0 - 0.5 # Preprocessing
if self.args.aug: if self.args.aug:
x = T.RandomCrop((80, 80))(x) # (None,80,80,4) x = T.RandomCrop((80, 80))(x) # (None,80,80,4)
x = T.functional.pad(x, (4, 4, 4, 4), "symmetric") # (None,88,88,4) x = T.functional.pad(x, (4, 4, 4, 4), "symmetric") # (None,88,88,4)
@ -494,6 +504,8 @@ class DPI:
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
writer = SummaryWriter()
dpi = DPI(args) dpi = DPI(args, writer)
dpi.train() dpi.train()

View File

@ -200,6 +200,10 @@ def make_env(args):
) )
return env return env
def preprocess_obs(obs):
obs = obs/255.0 - 0.5
return obs
def soft_update_params(net, target_net, tau): def soft_update_params(net, target_net, tau):
for param, target_param in zip(net.parameters(), target_net.parameters()): for param, target_param in zip(net.parameters(), target_net.parameters()):
target_param.data.copy_( target_param.data.copy_(
@ -301,4 +305,4 @@ class FreezeParameters:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
for i, param in enumerate(get_parameters(self.modules)): for i, param in enumerate(get_parameters(self.modules)):
param.requires_grad = self.param_states[i] param.requires_grad = self.param_states[i]