Compare commits


No commits in common. "9aa07fed6af3dea505ab65a77154663e38c6432f" and "3e9d8f7a9c5022a59a5db1403bf0e29fb449866a" have entirely different histories.

2 changed files with 214 additions and 315 deletions

View File

@@ -1,12 +1,15 @@
import numpy as np
import torch
import argparse
import os
import gc
import gym
import time
import json
import dmc2gym
import copy
import tqdm
import wandb
import random
import argparse
import numpy as np
import utils
from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
@@ -14,12 +17,13 @@ from logger import Logger
from video import VideoRecorder
from dmc2gym.wrappers import set_global_var
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter
#from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent
@@ -34,9 +38,8 @@ def parse_args():
parser.add_argument('--task_name', default='run')
parser.add_argument('--image_size', default=84, type=int)
parser.add_argument('--channels', default=3, type=int)
parser.add_argument('--action_repeat', default=2, type=int)
parser.add_argument('--action_repeat', default=1, type=int)
parser.add_argument('--frame_stack', default=3, type=int)
parser.add_argument('--collection_interval', default=100, type=int)
parser.add_argument('--resource_files', type=str)
parser.add_argument('--eval_resource_files', type=str)
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
@@ -49,11 +52,11 @@ def parse_args():
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
parser.add_argument('--init_steps', default=10000, type=int)
parser.add_argument('--num_train_steps', default=10000, type=int)
parser.add_argument('--batch_size', default=30, type=int) #512
parser.add_argument('--batch_size', default=20, type=int) #512
parser.add_argument('--state_size', default=256, type=int)
parser.add_argument('--hidden_size', default=128, type=int)
parser.add_argument('--history_size', default=128, type=int)
parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models')
parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models')
parser.add_argument('--load_encoder', default=None, type=str)
parser.add_argument('--imagine_horizon', default=15, type=int)
parser.add_argument('--grad_clip_norm', type=float, default=100.0, help='Gradient clipping norm')
@@ -61,13 +64,15 @@ def parse_args():
parser.add_argument('--eval_freq', default=10, type=int) # TODO: master had 10000
parser.add_argument('--num_eval_episodes', default=20, type=int)
# value
parser.add_argument('--value_lr', default=8e-5, type=float)
parser.add_argument('--value_lr', default=1e-4, type=float)
parser.add_argument('--value_beta', default=0.9, type=float)
parser.add_argument('--value_tau', default=0.005, type=float)
parser.add_argument('--value_target_update_freq', default=100, type=int)
parser.add_argument('--td_lambda', default=0.95, type=float)
# reward
parser.add_argument('--reward_lr', default=1e-4, type=float)
# actor
parser.add_argument('--actor_lr', default=8e-5, type=float)
parser.add_argument('--actor_lr', default=1e-4, type=float)
parser.add_argument('--actor_beta', default=0.9, type=float)
parser.add_argument('--actor_log_std_min', default=-10, type=float)
parser.add_argument('--actor_log_std_max', default=2, type=float)
@@ -75,7 +80,7 @@ def parse_args():
# world/encoder/decoder
parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
parser.add_argument('--encoder_feature_dim', default=50, type=int)
parser.add_argument('--world_model_lr', default=6e-4, type=float)
parser.add_argument('--world_model_lr', default=1e-3, type=float)
parser.add_argument('--past_transition_lr', default=1e-3, type=float)
parser.add_argument('--encoder_lr', default=1e-3, type=float)
parser.add_argument('--encoder_tau', default=0.001, type=float)
@@ -95,7 +100,6 @@ def parse_args():
# misc
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--logging_freq', default=100, type=int)
parser.add_argument('--saving_interval', default=1000, type=int)
parser.add_argument('--work_dir', default='.', type=str)
parser.add_argument('--save_tb', default=False, action='store_true')
parser.add_argument('--save_model', default=False, action='store_true')
@@ -103,6 +107,8 @@ def parse_args():
parser.add_argument('--save_video', default=False, action='store_true')
parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble'])
parser.add_argument('--render', default=False, action='store_true')
parser.add_argument('--port', default=2000, type=int)
parser.add_argument('--num_likelihood_updates', default=5, type=int)
args = parser.parse_args()
return args
@@ -113,7 +119,7 @@ def parse_args():
class DPI:
def __init__(self, args):
def __init__(self, args, writer):
# wandb config
#run = wandb.init(project="dpi")
@@ -135,8 +141,6 @@ class DPI:
# stack several consecutive frames together
if self.args.encoder_type.startswith('pixel'):
self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
self.env = utils.ActionRepeat(self.env, self.args.action_repeat)
self.env = utils.NormalizeActions(self.env)
# create replay buffer
self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity,
@@ -160,64 +164,64 @@ class DPI:
self.obs_encoder = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84)
state_size=self.args.state_size # 128
).to(device)
)
self.obs_encoder_momentum = ObservationEncoder(
obs_shape=(self.args.frame_stack*self.args.channels,self.args.image_size,self.args.image_size), # (9,84,84)
state_size=self.args.state_size # 128
).to(device)
)
self.obs_decoder = ObservationDecoder(
state_size=self.args.state_size, # 128
output_shape=(self.args.channels,self.args.image_size,self.args.image_size) # (3,84,84)
).to(device)
)
self.transition_model = TransitionModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
action_size=self.env.action_space.shape[0], # 6
history_size=self.args.history_size, # 128
).to(device)
)
# Actor Model
self.actor_model = Actor(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256,
action_size=self.env.action_space.shape[0], # 6
).to(device)
)
# Value Models
self.value_model = ValueModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
).to(device)
)
self.target_value_model = ValueModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
).to(device)
)
self.reward_model = RewardModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
).to(device)
)
# Contrastive Models
self.prjoection_head = ProjectionHead(
state_size=self.args.state_size, # 128
action_size=self.env.action_space.shape[0], # 6
hidden_size=self.args.hidden_size, # 256
).to(device)
)
self.prjoection_head_momentum = ProjectionHead(
state_size=self.args.state_size, # 128
action_size=self.env.action_space.shape[0], # 6
hidden_size=self.args.hidden_size, # 256
).to(device)
)
self.contrastive_head = ContrastiveHead(
hidden_size=self.args.hidden_size, # 256
).to(device)
)
# model parameters
@@ -233,7 +237,7 @@ class DPI:
self.past_transition_opt = torch.optim.Adam(self.past_transition_parameters, self.args.past_transition_lr)
# Create Modules
self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.reward_model, self.transition_model, self.prjoection_head]
self.world_model_modules = [self.obs_encoder, self.obs_decoder, self.value_model, self.transition_model, self.prjoection_head]
self.value_modules = [self.value_model]
self.actor_modules = [self.actor_model]
@@ -245,27 +249,21 @@ class DPI:
self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))
def collect_sequences(self, episodes, random=True, actor_model=None, encoder_model=None):
def collect_sequences(self, episodes):
obs = self.env.reset()
done = False
all_rews = []
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
if args.save_video:
self.env.video.init(enabled=True)
epi_reward = 0
for i in range(self.args.episode_length):
if random:
action = self.env.action_space.sample()
else:
with torch.no_grad():
obs_torch = torch.unsqueeze(torch.tensor(obs).float(),0).to(device)
state = self.obs_encoder(obs_torch)["distribution"].sample()
action = self.actor_model(state).cpu().detach().numpy().squeeze()
action = self.env.action_space.sample()
next_obs, rew, done, _ = self.env.step(action)
self.data_buffer.add(obs, action, next_obs, rew, episode_count+1, done)
if args.save_video:
@@ -276,222 +274,184 @@ class DPI:
done=False
else:
obs = next_obs
epi_reward += rew
all_rews.append(epi_reward)
if args.save_video:
self.env.video.save('noisy/%d.mp4' % episode_count)
print("Collected {} random episodes".format(episode_count+1))
return all_rews
def train(self, step, total_steps):
counter = 0
def train(self):
# collect experience
self.collect_sequences(self.args.batch_size)
# Group observations and next_observations by steps from past to present
last_observations = torch.tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()[:self.args.episode_length-1]
current_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()[:self.args.episode_length-1]
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()[1:]
actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[:self.args.episode_length-1]
next_actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()[1:]
rewards = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"rewards",obs=False)).float()[1:]
# Preprocessing
last_observations = preprocess_obs(last_observations)
current_observations = preprocess_obs(current_observations)
next_observations = preprocess_obs(next_observations)
# Initialize transition model states
self.transition_model.init_states(self.args.batch_size, device="cpu") # (N,128)
self.history = self.transition_model.prev_history # (N,128)
# Train encoder
step = 0
total_steps = 10000
metrics = {}
while step < total_steps:
for i in range(self.args.episode_length-1):
if i > 0:
# Encode observations and next_observations
self.last_states_dict = self.get_features(last_observations[i])
self.current_states_dict = self.get_features(current_observations[i])
self.next_states_dict = self.get_features(next_observations[i], momentum=True)
self.action = actions[i] # (N,6)
self.next_action = next_actions[i] # (N,6)
history = self.transition_model.prev_history
# collect experience
if step !=0:
encoder = self.obs_encoder
actor = self.actor_model
#all_rews = self.collect_sequences(self.args.batch_size, random=True)
all_rews = self.collect_sequences(self.args.batch_size, random=False, actor_model=actor, encoder_model=encoder)
else:
all_rews = self.collect_sequences(self.args.batch_size, random=True)
# Encode negative observations
idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch
random_time_index = torch.randint(0, self.args.episode_length-2, (1,)).item() # random time index
negative_current_observations = current_observations[random_time_index][idx]
self.negative_current_states_dict = self.obs_encoder(negative_current_observations)
# Group by steps and sample random batch
random_indices = self.data_buffer.sample_random_idx(self.args.batch_size * ((step//self.args.collection_interval)+1)) # random indices for batch
#random_indices = np.arange(self.args.batch_size * ((step//self.args.collection_interval)),self.args.batch_size * ((step//self.args.collection_interval)+1))
last_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"observations", "cpu", random_indices=random_indices)
current_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"next_observations", device="cpu", random_indices=random_indices)
next_observations = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"next_observations", device="cpu", offset=1, random_indices=random_indices)
actions = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"actions", device=device, is_obs=False, random_indices=random_indices)
next_actions = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"actions", device=device, is_obs=False, offset=1, random_indices=random_indices)
rewards = self.data_buffer.group_and_sample_random_batch(self.data_buffer,"rewards", device=device, is_obs=False, offset=1, random_indices=random_indices)
# Predict current state from past state with transition model
last_states_sample = self.last_states_dict["sample"]
predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
self.history = predicted_current_state_dict["history"]
# Preprocessing
last_observations = preprocess_obs(last_observations).to(device)
current_observations = preprocess_obs(current_observations).to(device)
next_observations = preprocess_obs(next_observations).to(device)
# Calculate upper bound loss
likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict,
self.current_states_dict,
self.negative_current_states_dict,
predicted_current_state_dict
)
#likeli_loss = torch.tensor(likeli_loss.numpy(),dtype=torch.float32, requires_grad=True)
#ikeli_loss = likeli_loss.mean()
# Initialize transition model states
self.transition_model.init_states(self.args.batch_size, device) # (N,128)
self.history = self.transition_model.prev_history # (N,128)
# Calculate encoder loss
encoder_loss = self._past_encoder_loss(self.current_states_dict,
predicted_current_state_dict)
# Train encoder
if step == 0:
step += 1
for _ in range(self.args.collection_interval // self.args.episode_length+1):
counter += 1
for i in range(self.args.episode_length-1):
if i > 0:
# Encode observations and next_observations
self.last_states_dict = self.get_features(last_observations[i])
self.current_states_dict = self.get_features(current_observations[i])
self.next_states_dict = self.get_features(next_observations[i], momentum=True)
self.action = actions[i] # (N,6)
self.next_action = next_actions[i] # (N,6)
history = self.transition_model.prev_history
#total_ub_loss += ub_loss
#total_encoder_loss += encoder_loss
# Encode negative observations
idx = torch.randperm(current_observations[i].shape[0]) # random permutation on batch
random_time_index = torch.randint(0, self.args.episode_length-2, (1,)).item() # random time index
negative_current_observations = current_observations[random_time_index][idx]
self.negative_current_states_dict = self.obs_encoder(negative_current_observations)
# contrastive projection
vec_anchor = predicted_current_state_dict["sample"]
vec_positive = self.next_states_dict["sample"].detach()
z_anchor = self.prjoection_head(vec_anchor, self.action)
z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach()
# Predict current state from past state with transition model
last_states_sample = self.last_states_dict["sample"]
predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
self.history = predicted_current_state_dict["history"]
# contrastive loss
logits = self.contrastive_head(z_anchor, z_positive)
labels = torch.arange(logits.shape[0]).long()
lb_loss = F.cross_entropy(logits, labels)
# Calculate upper bound loss
likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict,
self.current_states_dict,
self.negative_current_states_dict,
predicted_current_state_dict
)
# behaviour learning
with FreezeParameters(self.world_model_modules):
imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(),
self.next_action, self.history.detach(),
imagine_horizon)
# Calculate encoder loss
encoder_loss = self._past_encoder_loss(self.current_states_dict,
predicted_current_state_dict)
# decoder loss
horizon = np.minimum(50-i, imagine_horizon)
obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon])
decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:]))
# contrastive projection
vec_anchor = predicted_current_state_dict["sample"]
vec_positive = self.next_states_dict["sample"].detach()
z_anchor = self.prjoection_head(vec_anchor, self.action)
z_positive = self.prjoection_head_momentum(vec_positive, next_actions[i]).detach()
# reward loss
reward_dist = self.reward_model(self.current_states_dict["sample"])
reward_loss = -torch.mean(reward_dist.log_prob(rewards[:-1]))
# contrastive loss
logits = self.contrastive_head(z_anchor, z_positive)
labels = torch.arange(logits.shape[0]).long().to(device)
lb_loss = F.cross_entropy(logits, labels)
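ProjectionHead and ContrastiveHead are defined in models.py, outside this diff; a CURL-style bilinear head consistent with the usage above is sketched here (class name and body are assumptions, not the repo's code).
import torch
import torch.nn as nn

class BilinearContrastiveHead(nn.Module):
    """Assumed CURL-style similarity head: logits[i, j] = z_anchor[i] . W . z_positive[j]."""
    def __init__(self, hidden_size):
        super().__init__()
        self.W = nn.Parameter(torch.rand(hidden_size, hidden_size))

    def forward(self, z_anchor, z_positive):
        Wz = torch.matmul(self.W, z_positive.T)                     # (hidden, B)
        logits = torch.matmul(z_anchor, Wz)                         # (B, B); diagonal entries are positives
        return logits - torch.max(logits, dim=1, keepdim=True)[0]   # subtract row max for numerical stability
With labels = torch.arange(batch_size), the cross-entropy above then treats each anchor's own (diagonal) positive as the correct class.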
# update models
world_model_loss = encoder_loss + ub_loss + lb_loss + decoder_loss * 1e-2
self.world_model_opt.zero_grad()
world_model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step()
# behaviour learning
with FreezeParameters(self.world_model_modules):
imagine_horizon = self.args.imagine_horizon #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"].detach(),
self.next_action, self.history.detach(),
imagine_horizon)
# actor loss
with FreezeParameters(self.world_model_modules + self.value_modules):
imag_rew_dist = self.reward_model(imagined_rollout["sample"])
target_imag_val_dist = self.target_value_model(imagined_rollout["sample"])
# decoder loss
horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
obs_dist = self.obs_decoder(imagined_rollout["sample"][:horizon])
decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon][:,:,:3,:,:]))
imag_rews = imag_rew_dist.mean
target_imag_vals = target_imag_val_dist.mean
# reward loss
reward_dist = self.reward_model(self.current_states_dict["sample"])
reward_loss = -torch.mean(reward_dist.log_prob(rewards[:-1]))
discounts = self.args.discount * torch.ones_like(imag_rews).detach()
# update models
world_model_loss = encoder_loss + 100 * ub_loss + lb_loss + reward_loss + decoder_loss * 1e-2
self.world_model_opt.zero_grad()
world_model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
self.world_model_opt.step()
self.target_returns = self._compute_lambda_return(imag_rews[:-1],
target_imag_vals[:-1],
discounts[:-1] ,
self.args.td_lambda,
target_imag_vals[-1])
# update momentum encoder
soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
self.discounts = torch.cumprod(discounts, 0).detach()
actor_loss = -torch.mean(self.discounts * self.target_returns)
# update momentum projection head
soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
# update actor
self.actor_opt.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
self.actor_opt.step()
# actor loss
with FreezeParameters(self.world_model_modules + self.value_modules):
imag_rew_dist = self.reward_model(imagined_rollout["sample"])
target_imag_val_dist = self.target_value_model(imagined_rollout["sample"])
# value loss
with torch.no_grad():
value_feat = imagined_rollout["sample"][:-1].detach()
value_targ = self.target_returns.detach()
imag_rews = imag_rew_dist.mean
target_imag_vals = target_imag_val_dist.mean
value_dist = self.value_model(value_feat)
value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
discounts = self.args.discount * torch.ones_like(imag_rews).detach()
# update value
self.value_opt.zero_grad()
value_loss.backward()
nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
self.value_opt.step()
self.target_returns = self._compute_lambda_return(imag_rews[:-1],
target_imag_vals[:-1],
discounts[:-1] ,
self.args.td_lambda,
target_imag_vals[-1])
# update target value
if step % self.args.value_target_update_freq == 0:
self.target_value_model = copy.deepcopy(self.value_model)
discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
self.discounts = torch.cumprod(discounts, 0).detach()
actor_loss = -torch.mean(self.discounts * self.target_returns)
# update momentum encoder
soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
# update actor
self.actor_opt.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
self.actor_opt.step()
# update momentum projection head
soft_update_params(self.prjoection_head, self.prjoection_head_momentum, self.args.encoder_tau)
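soft_update_params is imported from utils and not shown in this diff; the usual Polyak/EMA form, assumed here, is:
def soft_update_params(net, target_net, tau):
    # Assumed implementation: target <- tau * online + (1 - tau) * target, parameter-wise.
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)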
# value loss
with torch.no_grad():
value_feat = imagined_rollout["sample"][:-1].detach()
value_targ = self.target_returns.detach()
step += 1
value_dist = self.value_model(value_feat)
value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1))
if step % self.args.logging_freq:
writer.add_scalar('Main Loss/World Loss', world_model_loss, step)
writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss, step)
writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, step)
writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss, step)
writer.add_scalar('Actor Critic Loss/Value Loss', value_loss, step)
writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss, step)
writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss, step)
writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss, step)
# update value
self.value_opt.zero_grad()
value_loss.backward()
nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
self.value_opt.step()
"""
if step % self.args.logging_freq:
metrics['Upper Bound Loss'] = ub_loss.item()
metrics['Encoder Loss'] = encoder_loss.item()
metrics['Decoder Loss'] = decoder_loss.item()
metrics["Lower Bound Loss"] = lb_loss.item()
metrics["World Model Loss"] = world_model_loss.item()
wandb.log(metrics)
"""
# update target value
if step % self.args.value_target_update_freq == 0:
self.target_value_model = copy.deepcopy(self.value_model)
# counter for reward
count = np.arange((counter-1) * (self.args.batch_size), (counter) * (self.args.batch_size))
if step>total_steps:
print("Training finished")
break
if step % self.args.logging_freq:
writer.add_scalar('World Loss/World Loss', world_model_loss.detach().item(), step)
writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss.detach().item(), step)
writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss, step)
writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), step)
writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), step)
writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), step)
writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), step)
step += 1
if step>total_steps:
print("Training finished")
break
# save model
if step % self.args.saving_interval == 0:
path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
self.save_models(path)
#torch.cuda.empty_cache() # memory leak issues
for j in range(len(all_rews)):
writer.add_scalar('Rewards/Rewards', all_rews[j], count[j])
def evaluate(self, env, eval_episodes, render=False):
episode_rew = np.zeros((eval_episodes))
video_images = [[] for _ in range(eval_episodes)]
for i in range(eval_episodes):
obs = env.reset()
done = False
prev_state = self.rssm.init_state(1, self.device)
prev_action = torch.zeros(1, self.action_size).to(self.device)
while not done:
with torch.no_grad():
posterior, action = self.act_with_world_model(obs, prev_state, prev_action)
action = action[0].cpu().numpy()
next_obs, rew, done, _ = env.step(action)
prev_state = posterior
prev_action = torch.tensor(action, dtype=torch.float32).to(self.device).unsqueeze(0)
episode_rew[i] += rew
if render:
video_images[i].append(obs['image'].transpose(1,2,0).copy())
obs = next_obs
return episode_rew, np.array(video_images[:self.args.max_videos_to_save])
def _upper_bound_minimization(self, last_states, current_states, negative_current_states, predicted_current_states):
club_sample = CLUBSample(last_states,
@@ -509,6 +469,8 @@ class DPI:
# predicted current state distribution
predicted_curr_states_dist = predicted_curr_states_dict["distribution"]
# KL divergence loss
loss = torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist).mean()
@@ -540,26 +502,10 @@ class DPI:
returns = torch.flip(torch.stack(rets), [0])
return returns
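Most of _compute_lambda_return falls outside this hunk; a standard TD(lambda) backward recursion that matches the visible tail and the call site in train() is sketched below (the exact body in the repo may differ).
import torch

def lambda_return(rewards, values, discounts, td_lam, last_value):
    # R_t = r_t + gamma_t * ((1 - lambda) * V(s_{t+1}) + lambda * R_{t+1}),
    # computed backwards from the bootstrap value of the final imagined state.
    next_values = torch.cat([values[1:], last_value.unsqueeze(0)], dim=0)
    targets = rewards + discounts * next_values * (1 - td_lam)
    rets, last = [], last_value
    for t in reversed(range(rewards.shape[0])):
        last = targets[t] + discounts[t] * td_lam * last
        rets.append(last)
    return torch.flip(torch.stack(rets), [0])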
def save_models(self, save_path):
torch.save(
{'rssm' : self.transition_model.state_dict(),
'actor': self.actor_model.state_dict(),
'reward_model': self.reward_model.state_dict(),
'obs_encoder': self.obs_encoder.state_dict(),
'obs_decoder': self.obs_decoder.state_dict(),
'actor_optimizer': self.actor_opt.state_dict(),
'value_optimizer': self.value_opt.state_dict(),
'world_model_optimizer': self.world_model_opt.state_dict(),}, save_path)
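A loading counterpart is not part of this diff; a sketch that mirrors the checkpoint keys written by save_models (method name and placement are assumptions):
def load_models(self, load_path):
    # Hypothetical inverse of save_models above; `device` as defined in __main__.
    checkpoint = torch.load(load_path, map_location=device)
    self.transition_model.load_state_dict(checkpoint['rssm'])
    self.actor_model.load_state_dict(checkpoint['actor'])
    self.reward_model.load_state_dict(checkpoint['reward_model'])
    self.obs_encoder.load_state_dict(checkpoint['obs_encoder'])
    self.obs_decoder.load_state_dict(checkpoint['obs_decoder'])
    self.actor_opt.load_state_dict(checkpoint['actor_optimizer'])
    self.value_opt.load_state_dict(checkpoint['value_optimizer'])
    self.world_model_opt.load_state_dict(checkpoint['world_model_optimizer'])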
if __name__ == '__main__':
args = parse_args()
writer = SummaryWriter()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
step = 0
total_steps = 10000
dpi = DPI(args)
dpi.train(step,total_steps)
dpi = DPI(args, writer)
dpi.train()

View File

@@ -1,3 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import random
import numpy as np
@@ -102,49 +108,6 @@ class FrameStack(gym.Wrapper):
return np.concatenate(list(self._frames), axis=0)
class ActionRepeat:
def __init__(self, env, amount):
self._env = env
self._amount = amount
def __getattr__(self, name):
return getattr(self._env, name)
def step(self, action):
done = False
total_reward = 0
current_step = 0
while current_step < self._amount and not done:
obs, reward, done, info = self._env.step(action)
total_reward += reward
current_step += 1
return obs, total_reward, done, info
class NormalizeActions:
def __init__(self, env):
self._env = env
self._mask = np.logical_and(
np.isfinite(env.action_space.low),
np.isfinite(env.action_space.high))
self._low = np.where(self._mask, env.action_space.low, -1)
self._high = np.where(self._mask, env.action_space.high, 1)
def __getattr__(self, name):
return getattr(self._env, name)
@property
def action_space(self):
low = np.where(self._mask, -np.ones_like(self._low), self._low)
high = np.where(self._mask, np.ones_like(self._low), self._high)
return gym.spaces.Box(low, high, dtype=np.float32)
def step(self, action):
original = (action + 1) / 2 * (self._high - self._low) + self._low
original = np.where(self._mask, original, action)
return self._env.step(original)
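Both wrappers removed here were previously chained around the environment in the training script (see the removed lines in the first file); a minimal standalone usage sketch, assuming the pre-0.26 gym step API used throughout this repo and a placeholder base environment:
import gym

env = gym.make("Pendulum-v1")                        # placeholder continuous-control env
env = NormalizeActions(ActionRepeat(env, amount=2))  # repeat each action twice, expose actions rescaled to [-1, 1]
obs = env.reset()                                    # delegated to the base env via __getattr__
obs, reward, done, info = env.step(env.action_space.sample())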
class ReplayBuffer:
def __init__(self, size, obs_shape, action_size, seq_len, batch_size, args):
self.size = size
@@ -201,11 +164,11 @@ class ReplayBuffer:
non_zero_indices = np.nonzero(buffer.episode_count)[0]
variable = variable[non_zero_indices]
if obs:
variable = variable.reshape(-1, self.args.episode_length,
variable = variable.reshape(self.args.batch_size, self.args.episode_length,
self.args.frame_stack*self.args.channels,
self.args.image_size,self.args.image_size).transpose(1, 0, 2, 3, 4)
else:
variable = variable.reshape(variable.shape[0]//self.args.episode_length, self.args.episode_length, -1).transpose(1, 0, 2)
variable = variable.reshape(self.args.batch_size, self.args.episode_length, -1).transpose(1, 0, 2)
return variable
def transform_grouped_steps(self, variable):
@@ -214,16 +177,6 @@ class ReplayBuffer:
self.args.image_size,self.args.image_size)
return variable
def sample_random_idx(self, buffer_length):
random_indices = random.sample(range(0, buffer_length), self.args.batch_size)
return random_indices
def group_and_sample_random_batch(self, buffer, variable_name, device, random_indices, is_obs=True, offset=0):
if offset == 0:
variable_tensor = torch.tensor(self.group_steps(buffer,variable_name, is_obs)).float()[:self.args.episode_length-1].to(device)
else:
variable_tensor = torch.tensor(self.group_steps(buffer,variable_name, is_obs)).float()[offset:].to(device)
return variable_tensor[:,random_indices,:,:,:] if is_obs else variable_tensor[:,random_indices,:]
def make_env(args):
# For making ground plane transparent, change rgba to (0, 0, 0, 0) in local_dm_control_suite/{domain_name}.xml,