Add encoder loss and include tqdm for visualization
This commit is contained in:
parent
a1fe81f018
commit
11f00ad695
70
DPI/train.py
70
DPI/train.py
@ -7,6 +7,7 @@ import time
|
||||
import json
|
||||
import dmc2gym
|
||||
|
||||
import tqdm
|
||||
import wandb
|
||||
import utils
|
||||
from utils import ReplayBuffer, make_env, save_image
|
||||
@ -33,10 +34,10 @@ def parse_args():
|
||||
parser.add_argument('--resource_files', type=str)
|
||||
parser.add_argument('--eval_resource_files', type=str)
|
||||
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
||||
parser.add_argument('--total_frames', default=10000, type=int)
|
||||
parser.add_argument('--total_frames', default=1000, type=int) # 10000
|
||||
parser.add_argument('--high_noise', action='store_true')
|
||||
# replay buffer
|
||||
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000
|
||||
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #50000
|
||||
parser.add_argument('--episode_length', default=50, type=int)
|
||||
# train
|
||||
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
|
||||
@ -130,10 +131,6 @@ class DPI:
|
||||
self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
|
||||
self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))
|
||||
|
||||
# create video recorder
|
||||
#video = VideoRecorder(video_dir if args.save_video else None, resource_files=args.resource_files)
|
||||
#video.init(enabled=True)
|
||||
|
||||
# create models
|
||||
self.build_models(use_saved=False, saved_model_dir=self.model_dir)
|
||||
|
||||
@ -174,28 +171,24 @@ class DPI:
|
||||
done = False
|
||||
|
||||
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
|
||||
for episode_count in range(episodes):
|
||||
self.env.video.init(enabled=True)
|
||||
for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
|
||||
#self.env.video.init(enabled=True)
|
||||
for i in range(self.args.episode_length):
|
||||
action = self.env.action_space.sample()
|
||||
next_obs, _, done, _ = self.env.step(action)
|
||||
|
||||
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
|
||||
|
||||
if args.save_video:
|
||||
self.env.video.record(self.env)
|
||||
#if args.save_video:
|
||||
# self.env.video.record(self.env)
|
||||
|
||||
if done:
|
||||
obs = self.env.reset()
|
||||
done=False
|
||||
else:
|
||||
obs = next_obs
|
||||
self.env.video.save('%d.mp4' % episode_count)
|
||||
#self.env.video.save('%d.mp4' % episode_count)
|
||||
print("Collected {} random episodes".format(episode_count+1))
|
||||
#if args.save_video:
|
||||
# video.record(self.env)
|
||||
#video.save('%d.mp4' % step)
|
||||
#video.close()
|
||||
|
||||
def train(self):
|
||||
# collect experience
|
||||
@ -204,26 +197,59 @@ class DPI:
|
||||
# Group observations and next_observations by steps
|
||||
observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()
|
||||
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()
|
||||
actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()
|
||||
|
||||
# Initialize transition model states
|
||||
self.transition_model.init_states(self.args.batch_size, device="cpu") # (N,128)
|
||||
self.history = self.transition_model.prev_history # (N,128)
|
||||
|
||||
# Train encoder
|
||||
previous_information_loss = 0
|
||||
previous_encoder_loss = 0
|
||||
for i in range(self.args.episode_length):
|
||||
# Encode observations and next_observations
|
||||
self.features = self.obs_encoder(observations[i]) # (N,128)
|
||||
self.next_features = self.obs_encoder(next_observations[i]) # (N,128)
|
||||
self.states_dist = self.obs_encoder(observations[i])
|
||||
self.next_states_dist = self.obs_encoder(next_observations[i])
|
||||
|
||||
# Sample states and next_states
|
||||
self.states = self.states_dist["sample"] # (N,128)
|
||||
self.next_states = self.next_states_dist["sample"] # (N,128)
|
||||
self.actions = actions[i] # (N,6)
|
||||
|
||||
# Calculate upper bound loss
|
||||
past_loss = previous_information_loss + self.upper_bound_minimization(self.features, self.next_features)
|
||||
previous_information_loss = past_loss
|
||||
print("past_loss: ", past_loss)
|
||||
past_latent_loss = previous_information_loss + self._upper_bound_minimization(self.states, self.next_states)
|
||||
|
||||
def upper_bound_minimization(self, features, next_features):
|
||||
# Calculate encoder loss
|
||||
past_encoder_loss = previous_encoder_loss + self._past_encoder_loss(self.states, self.next_states,
|
||||
self.states_dist, self.next_states_dist,
|
||||
self.actions, self.history, i)
|
||||
|
||||
previous_information_loss = past_latent_loss
|
||||
previous_encoder_loss = past_encoder_loss
|
||||
|
||||
def _upper_bound_minimization(self, states, next_states):
|
||||
club_sample = CLUBSample(self.args.state_size,
|
||||
self.args.state_size,
|
||||
self.args.hidden_size)
|
||||
club_loss = club_sample(features, next_features)
|
||||
club_loss = club_sample(states, next_states)
|
||||
return club_loss
|
||||
|
||||
def _past_encoder_loss(self, states, next_states, states_dist, next_states_dist, actions, history, step):
|
||||
# Imagine next state
|
||||
if step == 0:
|
||||
actions = torch.zeros(self.args.batch_size, self.env.action_space.shape[0]).float() # Zero action for first step
|
||||
imagined_next_states = self.transition_model.imagine_step(states, actions, history)
|
||||
self.history = imagined_next_states["history"]
|
||||
else:
|
||||
imagined_next_states = self.transition_model.imagine_step(states, actions, self.history) # (N,128)
|
||||
|
||||
# State Distribution
|
||||
imagined_next_states_dist = imagined_next_states["distribution"]
|
||||
|
||||
# KL divergence loss
|
||||
loss = torch.distributions.kl.kl_divergence(imagined_next_states_dist, next_states_dist["distribution"]).mean()
|
||||
|
||||
return loss
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
Loading…
Reference in New Issue
Block a user