Curiosity/DPI/train.py

import os
import gc
import copy
import tqdm
import wandb
import random
import argparse
import numpy as np
import utils
from utils import ReplayBuffer, FreezeParameters, make_env, preprocess_obs, soft_update_params, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, Actor, ValueModel, RewardModel, ProjectionHead, ContrastiveHead, CLUBSample
from logger import Logger
from video import VideoRecorder
from dmc2gym.wrappers import set_global_var
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter
#from agent.baseline_agent import BaselineAgent
#from agent.bisim_agent import BisimAgent
#from agent.deepmdp_agent import DeepMDPAgent
#from agents.navigation.carla_env import CarlaEnv
def parse_args():
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='cheetah')
    parser.add_argument('--version', default=1, type=int)
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--image_size', default=84, type=int)
    parser.add_argument('--channels', default=3, type=int)
    parser.add_argument('--action_repeat', default=2, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--collection_interval', default=100, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)  # 10000
    parser.add_argument('--high_noise', action='store_true')
    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=50000, type=int)  # 50000
    parser.add_argument('--episode_length', default=21, type=int)
    # train
    parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
    parser.add_argument('--init_steps', default=10000, type=int)
    parser.add_argument('--num_train_steps', default=100000, type=int)
    parser.add_argument('--batch_size', default=128, type=int)  # 512
    parser.add_argument('--state_size', default=30, type=int)
    parser.add_argument('--hidden_size', default=256, type=int)
    parser.add_argument('--history_size', default=128, type=int)
    parser.add_argument('--episode_collection', default=5, type=int)
    parser.add_argument('--episodes_buffer', default=20, type=int)
    parser.add_argument('--num-units', type=int, default=50, help='num hidden units for reward/value/discount models')
    parser.add_argument('--load_encoder', default=None, type=str)
    parser.add_argument('--imagine_horizon', default=10, type=int)
    parser.add_argument('--grad_clip_norm', type=float, default=100.0, help='gradient clipping norm')
    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    parser.add_argument('--evaluation_interval', default=10000, type=int)  # TODO: master had 10000
    # value
    parser.add_argument('--value_lr', default=8e-6, type=float)
    parser.add_argument('--value_beta', default=0.9, type=float)
    parser.add_argument('--value_tau', default=0.005, type=float)
    parser.add_argument('--value_target_update_freq', default=100, type=int)
    parser.add_argument('--td_lambda', default=0.95, type=float)
    # actor
    parser.add_argument('--actor_lr', default=8e-6, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # world/encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla096', 'pixelCarla098', 'identity'])
    parser.add_argument('--world_model_lr', default=1e-5, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity', 'contrastive', 'reward', 'inverse', 'reconstruction'])
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    parser.add_argument('--aug', action='store_true')
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--logging_freq', default=100, type=int)
    parser.add_argument('--saving_interval', default=2500, type=int)
    parser.add_argument('--work_dir', default='.', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str, choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--render', default=False, action='store_true')
    args = parser.parse_args()
    return args
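
# Example invocation (illustrative only; every flag shown is defined in the parser above and the
# values are its documented defaults made explicit):
#   python train.py --domain_name cheetah --task_name run --seed 1 --batch_size 128 --save_video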


class DPI:
    def __init__(self, args):
        # wandb config
        #run = wandb.init(project="dpi")
        self.args = args

        # set environment noise
        set_global_var(self.args.high_noise)

        # environment setup
        self.env = make_env(self.args)
        #self.args.seed = np.random.randint(0, 1000)
        self.env.seed(self.args.seed)
        self.global_episodes = 0

        # noiseless environment setup
        self.args.version = 2           # env_id changes to v2
        self.args.img_source = None     # no image noise
        self.args.resource_files = None

        # stack several consecutive frames together
        if self.args.encoder_type.startswith('pixel'):
            self.env = utils.FrameStack(self.env, k=self.args.frame_stack)
        self.env = utils.ActionRepeat(self.env, self.args.action_repeat)
        self.env = utils.NormalizeActions(self.env)

        # create replay buffer
        self.data_buffer = ReplayBuffer(size=self.args.replay_buffer_capacity,
                                        obs_shape=(self.args.frame_stack*self.args.channels, self.args.image_size, self.args.image_size),
                                        action_size=self.env.action_space.shape[0],
                                        seq_len=self.args.episode_length,
                                        batch_size=args.batch_size,
                                        args=self.args)

        # create work directory
        utils.make_dir(self.args.work_dir)
        self.video_dir = utils.make_dir(os.path.join(self.args.work_dir, 'video'))
        self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
        self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))

        # create models
        self.build_models(use_saved=False, saved_model_dir=self.model_dir)

    def build_models(self, use_saved, saved_model_dir=None):
        # World Models
        self.obs_encoder = ObservationEncoder(
            obs_shape=(self.args.frame_stack*self.args.channels, self.args.image_size, self.args.image_size),  # (9,84,84)
            state_size=self.args.state_size  # 30
        ).to(device)

        self.obs_encoder_momentum = ObservationEncoder(
            obs_shape=(self.args.frame_stack*self.args.channels, self.args.image_size, self.args.image_size),  # (9,84,84)
            state_size=self.args.state_size  # 30
        ).to(device)

        self.obs_decoder = ObservationDecoder(
            state_size=self.args.state_size,  # 30
            output_shape=(self.args.frame_stack*self.args.channels, self.args.image_size, self.args.image_size)  # (9,84,84)
        ).to(device)

        self.transition_model = TransitionModel(
            state_size=self.args.state_size,    # 30
            hidden_size=self.args.hidden_size,  # 256
            action_size=self.env.action_space.shape[0],  # 6
            history_size=self.args.history_size,  # 128
        ).to(device)

        # Actor Model
        self.actor_model = Actor(
            state_size=self.args.state_size,    # 30
            hidden_size=self.args.hidden_size,  # 256
            action_size=self.env.action_space.shape[0],  # 6
        ).to(device)
        self.actor_model.apply(self.init_weights)

        # Value Models
        self.value_model = ValueModel(
            state_size=self.args.state_size,    # 30
            hidden_size=self.args.hidden_size,  # 256
        ).to(device)

        self.target_value_model = ValueModel(
            state_size=self.args.state_size,    # 30
            hidden_size=self.args.hidden_size,  # 256
        ).to(device)

        self.reward_model = RewardModel(
            state_size=self.args.state_size,    # 30
            hidden_size=self.args.hidden_size,  # 256
        ).to(device)

        # Contrastive Models
        self.projection_head = ProjectionHead(
            state_size=self.args.state_size,    # 30
            action_size=self.env.action_space.shape[0],  # 6
            hidden_size=self.args.hidden_size,  # 256
        ).to(device)

        self.projection_head_momentum = ProjectionHead(
            state_size=self.args.state_size,    # 30
            action_size=self.env.action_space.shape[0],  # 6
            hidden_size=self.args.hidden_size,  # 256
        ).to(device)

        self.contrastive_head = ContrastiveHead(
            hidden_size=self.args.hidden_size,  # 256
        ).to(device)

        self.club_sample = CLUBSample(
            x_dim=self.args.state_size,  # 30
            y_dim=self.args.state_size,  # 30
            hidden_size=self.args.hidden_size,  # 256
        ).to(device)

        # model parameters
        self.world_model_parameters = list(self.obs_encoder.parameters()) + list(self.projection_head.parameters()) + \
                                      list(self.transition_model.parameters()) + list(self.obs_decoder.parameters()) + \
                                      list(self.reward_model.parameters()) + list(self.club_sample.parameters())
        self.past_transition_parameters = self.transition_model.parameters()

        # optimizers
        self.world_model_opt = torch.optim.Adam(self.world_model_parameters, self.args.world_model_lr)
        self.value_opt = torch.optim.Adam(self.value_model.parameters(), self.args.value_lr)
        self.actor_opt = torch.optim.Adam(self.actor_model.parameters(), self.args.actor_lr)
        #self.reward_opt = torch.optim.Adam(self.reward_model.parameters(), 1e-5)
        #self.decoder_opt = torch.optim.Adam(self.obs_decoder.parameters(), 1e-4)

        # create module groups (used for freezing parameters during actor/value updates)
        self.world_model_modules = [self.obs_encoder, self.projection_head, self.transition_model, self.obs_decoder, self.reward_model, self.club_sample]
        self.value_modules = [self.value_model]
        self.actor_modules = [self.actor_model]
        #self.reward_modules = [self.reward_model]
        #self.decoder_modules = [self.obs_decoder]

        if use_saved:
            self._use_saved_models(saved_model_dir)

    def _use_saved_models(self, saved_model_dir):
        self.obs_encoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_encoder.pt')))
        self.obs_decoder.load_state_dict(torch.load(os.path.join(saved_model_dir, 'obs_decoder.pt')))
        self.transition_model.load_state_dict(torch.load(os.path.join(saved_model_dir, 'transition_model.pt')))

    def collect_sequences(self, episodes, random=True, actor_model=None, encoder_model=None):
        obs = self.env.reset()
        done = False
        all_rews = []
        for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
            self.global_episodes += 1
            epi_reward = 0
            while not done:
                if random:
                    action = self.env.action_space.sample()
                else:
                    with torch.no_grad():
                        obs_torch = torch.tensor(obs.copy(), dtype=torch.float32).unsqueeze(0)
                        obs_processed = preprocess_obs(obs_torch).to(device)
                        state = self.obs_encoder(obs_processed)["distribution"].sample()
                        action = self.actor_model(state).cpu().numpy().squeeze()
                        #action = self.env.action_space.sample()

                next_obs, rew, done, _ = self.env.step(action)
                self.data_buffer.add(obs, action, next_obs, rew, done, self.global_episodes)

                obs = next_obs
                epi_reward += rew

            obs = self.env.reset()
            done = False
            all_rews.append(epi_reward)
        return all_rews

    def train(self, step, total_steps):
        counter = 0
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()

        while step < total_steps:
            # collect experience
            if step != 0:
                encoder = self.obs_encoder
                actor = self.actor_model
                all_rews = self.collect_sequences(self.args.episode_collection, random=False, actor_model=actor, encoder_model=encoder)
            else:
                all_rews = self.collect_sequences(self.args.episodes_buffer, random=True)

            # collect sequences
            non_zero_indices = np.nonzero(self.data_buffer.episode_count)[0]
            current_obs = self.data_buffer.observations[non_zero_indices]
            next_obs = self.data_buffer.next_observations[non_zero_indices]
            actions_raw = self.data_buffer.actions[non_zero_indices]
            rewards = self.data_buffer.rewards[non_zero_indices]
            self.terms = np.where(self.data_buffer.terminals[non_zero_indices] != 0)[0]

            # Group by episodes
            current_obs = self.grouped_arrays(current_obs)
            next_obs = self.grouped_arrays(next_obs)
            actions_raw = self.grouped_arrays(actions_raw)
            rewards_ = self.grouped_arrays(rewards)

            # Train encoder
            if step == 0:
                step += 1

            update_steps = 1
            #for _ in range(self.args.collection_interval // self.args.episode_length+1):
            for _ in range(update_steps):
                counter += 1
                # Select random chunks of episodes
                if current_obs.shape[0] < self.args.batch_size:
                    random_episode_number = np.random.randint(0, current_obs.shape[0], self.args.batch_size)
                else:
                    random_episode_number = random.sample(range(current_obs.shape[0]), self.args.batch_size)

                if current_obs[0].shape[0] - self.args.episode_length < self.args.batch_size:
                    init_index = np.random.randint(0, current_obs[0].shape[0]-self.args.episode_length-2, self.args.batch_size)
                else:
                    init_index = np.asarray(random.sample(range(current_obs[0].shape[0]-self.args.episode_length), self.args.batch_size))
                random.shuffle(random_episode_number)
                random.shuffle(init_index)

                last_observations = self.select_first_k(current_obs, init_index, random_episode_number)[:-1]
                current_observations = self.select_first_k(current_obs, init_index, random_episode_number)[1:]
                next_observations = self.select_first_k(next_obs, init_index, random_episode_number)[:-1]
                actions = self.select_first_k(actions_raw, init_index, random_episode_number)[:-1].to(device)
                next_actions = self.select_first_k(actions_raw, init_index, random_episode_number)[1:].to(device)
                rewards = self.select_first_k(rewards_, init_index, random_episode_number)[1:].to(device)

                # Preprocessing
                last_observations = preprocess_obs(last_observations).to(device)
                current_observations = preprocess_obs(current_observations).to(device)
                next_observations = preprocess_obs(next_observations).to(device)

                # Initialize transition model states
                self.transition_model.init_states(self.args.batch_size, device)  # (N,128)
                self.history = self.transition_model.prev_history  # (N,128)

                past_world_model_loss = 0
                past_action_loss = 0
                past_value_loss = 0
                for i in range(self.args.episode_length-1):
                    if i > 0:
                        # Encode observations and next_observations
                        self.last_states_dict = self.get_features(last_observations[i])
                        self.current_states_dict = self.get_features(current_observations[i])
                        self.next_states_dict = self.get_features(next_observations[i], momentum=True)
                        self.action = actions[i]            # (N,6)
                        self.next_action = next_actions[i]  # (N,6)
                        history = self.transition_model.prev_history

                        # Encode negative observations for the upper bound loss: a batch permutation
                        # at a randomly chosen time step serves as the negative samples
                        idx = torch.randperm(current_observations[i].shape[0])  # random permutation on batch
                        random_time_index = torch.randint(0, current_observations.shape[0]-2, (1,)).item()  # random time index
                        negative_current_observations = current_observations[random_time_index][idx]
                        self.negative_current_states_dict = self.get_features(negative_current_observations)

                        # Predict current state from past state with transition model
                        last_states_sample = self.last_states_dict["sample"]
                        predicted_current_state_dict = self.transition_model.imagine_step(last_states_sample, self.action, self.history)
                        self.history = predicted_current_state_dict["history"]

                        # Calculate upper bound loss
                        likeli_loss, ub_loss = self._upper_bound_minimization(self.last_states_dict["sample"].detach(),
                                                                              self.current_states_dict["sample"].detach(),
                                                                              self.negative_current_states_dict["sample"].detach(),
                                                                              predicted_current_state_dict["sample"].detach())
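                        # CLUBSample appears to estimate a sample-based upper bound on the mutual
                        # information between the encoded current state and the state predicted by the
                        # transition model, using the shuffled-batch states above as negatives. Since
                        # every input is detached here, this term only trains the CLUB estimator itself.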
                        # Calculate encoder loss
                        encoder_loss = self._past_encoder_loss(self.current_states_dict, predicted_current_state_dict)

                        # decoder loss
                        horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
                        nxt_obs = next_observations[i:i+horizon].reshape(-1, 9, 84, 84)
                        next_states_encodings = self.get_features(nxt_obs)["sample"].view(horizon, self.args.batch_size, -1)
                        obs_dist = self.obs_decoder(next_states_encodings)
                        decoder_loss = -torch.mean(obs_dist.log_prob(next_observations[i:i+horizon]))
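                        # The decoder loss is the negative log-likelihood of the next `horizon`
                        # observations under the decoder's output distribution, evaluated on states
                        # encoded from those same observations (a reconstruction term, not a rollout
                        # of the transition model).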
                        # contrastive projection
                        vec_anchor = predicted_current_state_dict["sample"].detach()
                        vec_positive = self.next_states_dict["sample"].detach()
                        z_anchor = self.projection_head(vec_anchor, self.action)
                        z_positive = self.projection_head_momentum(vec_positive, next_actions[i]).detach()

                        # contrastive loss
                        logits = self.contrastive_head(z_anchor, z_positive)
                        labels = torch.arange(logits.shape[0]).long().to(device)
                        lb_loss = F.cross_entropy(logits, labels)
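                        # InfoNCE-style lower bound: each projected anchor should score highest against
                        # its own positive (the diagonal of `logits`), with the rest of the batch acting
                        # as negatives. Only the anchor branch receives gradients; the positive branch
                        # goes through the momentum projection head and is detached.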
                        # reward loss
                        reward_dist = self.reward_model(self.current_states_dict["sample"])
                        reward_loss = -torch.mean(reward_dist.log_prob(rewards[i]))

                        # world model loss
                        world_model_loss = (10*encoder_loss + 10*ub_loss + 1e-1*lb_loss + reward_loss + 1e-3*decoder_loss + past_world_model_loss) * 1e-3
                        past_world_model_loss = world_model_loss.item()
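                        # The world-model objective is a weighted sum: 10*encoder (KL) + 10*upper bound
                        # + 0.1*lower bound (InfoNCE) + reward NLL + 1e-3*decoder NLL, all scaled by 1e-3.
                        # The previous step's loss is folded in via .item(), i.e. as a constant with no
                        # gradient; the weights appear to be hand-tuned for this setup.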
                        # actor loss
                        with FreezeParameters(self.world_model_modules):
                            imagine_horizon = self.args.imagine_horizon  #np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
                            action = self.actor_model(self.current_states_dict["sample"])
                            imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"],
                                                                                     action, self.history,
                                                                                     imagine_horizon)
                        with FreezeParameters(self.world_model_modules + self.value_modules):
                            imag_rewards = self.reward_model(imagined_rollout["sample"]).mean
                            imag_values = self.target_value_model(imagined_rollout["sample"]).mean

                            discounts = self.args.discount * torch.ones_like(imag_rewards).detach()

                            self.returns = self._compute_lambda_return(imag_rewards[:-1],
                                                                       imag_values[:-1],
                                                                       discounts[:-1],
                                                                       self.args.td_lambda,
                                                                       imag_values[-1])

                            discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
                            self.discounts = torch.cumprod(discounts, 0).detach()
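                        # Lambda-return over the imagined trajectory (Dreamer-style):
                        #   G_t^lambda = r_t + gamma_t * [(1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}^lambda],
                        # bootstrapped with the target value of the final imagined state. `self.discounts`
                        # is the cumulative product of per-step discounts (first step undiscounted) and
                        # weights both the actor and the value objectives below.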
                        actor_loss = -torch.mean(self.discounts * self.returns) + past_action_loss
                        past_action_loss = actor_loss.item()

                        # value loss
                        with torch.no_grad():
                            value_feat = imagined_rollout["sample"][:-1].detach()
                            value_targ = self.returns.detach()
                        value_dist = self.value_model(value_feat)
                        value_loss = -torch.mean(self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1)) + past_value_loss
                        past_value_loss = value_loss.item()

                        # update target value
                        if step % self.args.value_target_update_freq == 0:
                            self.target_value_model = copy.deepcopy(self.value_model)
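                        # Note: this is a hard target-network update (a full copy every
                        # value_target_update_freq steps), in contrast to the soft/EMA updates used for
                        # the momentum encoder and projection head further down.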
                        # counter for reward
                        #count = np.arange((counter-1) * (self.args.batch_size), (counter) * (self.args.batch_size))
                        count = (counter-1) * (self.args.batch_size)

                        # log every logging_freq steps
                        if step % self.args.logging_freq == 0:
                            writer.add_scalar('World Loss/World Loss', world_model_loss.detach().item(), self.data_buffer.steps)
                            writer.add_scalar('Main Models Loss/Encoder Loss', encoder_loss.detach().item(), self.data_buffer.steps)
                            writer.add_scalar('Main Models Loss/Decoder Loss', decoder_loss.detach().item(), self.data_buffer.steps)
                            writer.add_scalar('Actor Critic Loss/Actor Loss', actor_loss.detach().item(), self.data_buffer.steps)
                            writer.add_scalar('Actor Critic Loss/Value Loss', value_loss.detach().item(), self.data_buffer.steps)
                            writer.add_scalar('Actor Critic Loss/Reward Loss', reward_loss.detach().item(), self.data_buffer.steps)
                            writer.add_scalar('Bound Loss/Upper Bound Loss', ub_loss.detach().item(), self.data_buffer.steps)
                            writer.add_scalar('Bound Loss/Lower Bound Loss', lb_loss.detach().item(), self.data_buffer.steps)

                        step += 1
                        print(world_model_loss, actor_loss, value_loss)

                        # update actor model
                        self.actor_opt.zero_grad()
                        actor_loss.backward()
                        nn.utils.clip_grad_norm_(self.actor_model.parameters(), self.args.grad_clip_norm)
                        self.actor_opt.step()

                        # update world model
                        self.world_model_opt.zero_grad()
                        world_model_loss.backward()
                        nn.utils.clip_grad_norm_(self.world_model_parameters, self.args.grad_clip_norm)
                        self.world_model_opt.step()

                        # update value model
                        self.value_opt.zero_grad()
                        value_loss.backward()
                        nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.grad_clip_norm)
                        self.value_opt.step()

                        # update momentum encoder and projection head
                        soft_update_params(self.obs_encoder, self.obs_encoder_momentum, self.args.encoder_tau)
                        soft_update_params(self.projection_head, self.projection_head_momentum, self.args.encoder_tau)

                rew_len = np.arange(count, count+self.args.episode_collection) if count != 0 else np.arange(0, self.args.batch_size)
                for j in range(len(all_rews)):
                    writer.add_scalar('Rewards/Rewards', all_rews[j], rew_len[j])

            print(step)
            if step % 2850 == 0 and self.data_buffer.steps != 0:  #self.args.evaluation_interval == 0:
                print("Saving model")
                path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
                self.save_models(path)
                self.evaluate()

    def evaluate(self, eval_episodes=10):
        path = os.path.dirname(os.path.realpath(__file__)) + "/saved_models/models.pth"
        self.restore_checkpoint(path)

        obs = self.env.reset()
        done = False

        #video = VideoRecorder(self.video_dir, resource_files=self.args.resource_files)
        if self.args.save_video:
            self.env.video.init(enabled=True)

        episodic_rewards = []
        for episode in range(eval_episodes):
            rewards = 0
            done = False
            while not done:
                with torch.no_grad():
                    obs_torch = torch.tensor(obs.copy(), dtype=torch.float32).unsqueeze(0)
                    obs_processed = preprocess_obs(obs_torch).to(device)
                    state = self.obs_encoder(obs_processed)["distribution"].sample()
                    action = self.actor_model(state).cpu().detach().numpy().squeeze()
                next_obs, rew, done, _ = self.env.step(action)
                rewards += rew
                if self.args.save_video:
                    self.env.video.record(self.env)
                    self.env.video.save('/home/vedant/Curiosity/Curiosity/DPI/log/video/learned_model.mp4')
                obs = next_obs
            obs = self.env.reset()
            episodic_rewards.append(rewards)
        print("Episodic rewards: ", episodic_rewards)
        print("Average episodic reward: ", np.mean(episodic_rewards))

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def grouped_arrays(self, array):
        indices = [0] + self.terms.tolist()

        def subarrays():
            for start, end in zip(indices[:-1], indices[1:]):
                yield array[start:end]

        try:
            subarrays = np.stack(list(subarrays()), axis=0)
        except ValueError:
            subarrays = np.asarray(list(subarrays()))
        return subarrays

    def select_first_k(self, array, init_index, episode_number):
        term_index = init_index + self.args.episode_length
        array = array[episode_number]

        array_list = []
        for i in range(array.shape[0]):
            array_list.append(array[i][init_index[i]:term_index[i]])
        array = np.asarray(array_list)

        if array.ndim == 5:
            transposed_array = np.transpose(array, (1, 0, 2, 3, 4))
        elif array.ndim == 4:
            transposed_array = np.transpose(array, (1, 0, 2, 3))
        elif array.ndim == 3:
            transposed_array = np.transpose(array, (1, 0, 2))
        elif array.ndim == 2:
            transposed_array = np.transpose(array, (1, 0))
        else:
            transposed_array = np.expand_dims(array, axis=0)
        return torch.tensor(transposed_array).float()
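
    # Note on select_first_k: it slices a fixed-length window (episode_length) out of each chosen
    # episode and moves the time axis first, so downstream code can index tensors as (T, batch, ...),
    # e.g. observations come out as (episode_length, batch_size, 9, 84, 84) before the [:-1]/[1:] shifts.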

    def _upper_bound_minimization(self, last_states, current_states, negative_current_states, predicted_current_states):
        club_loss = self.club_sample(current_states, predicted_current_states, negative_current_states)
        likelihood_loss = 0
        return likelihood_loss, club_loss

    def _past_encoder_loss(self, curr_states_dict, predicted_curr_states_dict):
        # current state distribution
        curr_states_dist = curr_states_dict["distribution"]
        # predicted current state distribution
        predicted_curr_states_dist = predicted_curr_states_dict["distribution"]
        # KL divergence loss
        loss = torch.mean(torch.distributions.kl.kl_divergence(curr_states_dist, predicted_curr_states_dist))
        return loss
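
    # Note on _past_encoder_loss: this is the mean KL(q(s_t | o_t) || p(s_t | s_{t-1}, a_{t-1})),
    # a consistency term that pulls the transition model's one-step prediction toward the
    # distribution the encoder assigns to the actual current observation.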

    def get_features(self, x, momentum=False):
        if self.args.aug:
            crop_transform = T.RandomCrop(size=80)
            cropped_x = torch.stack([crop_transform(x[i]) for i in range(x.size(0))])
            padding = (2, 2, 2, 2)
            x = F.pad(cropped_x, padding)

        # only the momentum encoder is evaluated without gradients; the online encoder is part of
        # world_model_parameters and receives gradients from the world-model loss
        if momentum:
            with torch.no_grad():
                x = self.obs_encoder_momentum(x)
        else:
            x = self.obs_encoder(x)
        return x

    def _compute_lambda_return(self, rewards, values, discounts, td_lam, last_value):
        next_values = torch.cat([values[1:], last_value.unsqueeze(0)], 0)
        targets = rewards + discounts * next_values * (1 - td_lam)
        rets = []
        last_rew = last_value

        for t in range(rewards.shape[0]-1, -1, -1):
            last_rew = targets[t] + discounts[t] * td_lam * last_rew
            rets.append(last_rew)

        returns = torch.flip(torch.stack(rets), [0])
        return returns
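
    # Note on _compute_lambda_return: backward recursion for the TD(lambda) return,
    #   G_t^lambda = r_t + gamma_t * [(1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}^lambda],  G_T^lambda = V(s_T).
    # `targets` holds the r_t + gamma_t * (1 - lambda) * V(s_{t+1}) part; the loop adds the
    # gamma_t * lambda * G_{t+1}^lambda part from the end of the horizon backwards, and the final
    # flip restores time order.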

    def save_models(self, save_path):
        # make sure the checkpoint directory exists before saving
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        # note: the value model's weights are not checkpointed here, only its optimizer state
        torch.save(
            {'rssm': self.transition_model.state_dict(),
             'actor': self.actor_model.state_dict(),
             'reward_model': self.reward_model.state_dict(),
             'obs_encoder': self.obs_encoder.state_dict(),
             'obs_decoder': self.obs_decoder.state_dict(),
             'actor_optimizer': self.actor_opt.state_dict(),
             'value_optimizer': self.value_opt.state_dict(),
             'world_model_optimizer': self.world_model_opt.state_dict()}, save_path)

    def restore_checkpoint(self, ckpt_path):
        checkpoint = torch.load(ckpt_path)
        self.transition_model.load_state_dict(checkpoint['rssm'])
        self.actor_model.load_state_dict(checkpoint['actor'])
        self.reward_model.load_state_dict(checkpoint['reward_model'])
        self.obs_encoder.load_state_dict(checkpoint['obs_encoder'])
        self.obs_decoder.load_state_dict(checkpoint['obs_decoder'])
        self.world_model_opt.load_state_dict(checkpoint['world_model_optimizer'])
        self.actor_opt.load_state_dict(checkpoint['actor_optimizer'])
        self.value_opt.load_state_dict(checkpoint['value_optimizer'])


if __name__ == '__main__':
    args = parse_args()

    writer = SummaryWriter()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    step = 0
    total_steps = 500000
    dpi = DPI(args)
    dpi.train(step, total_steps)
    dpi.evaluate()