Adding new background videos for each episode

This commit is contained in:
Vedant Dave 2023-03-25 14:18:07 +01:00
parent 43f862ee6d
commit 8464503dd8
4 changed files with 33 additions and 6 deletions

View File

@ -131,12 +131,16 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
self.p_mu = nn.Sequential( self.p_mu = nn.Sequential(
nn.Linear(x_dim, hidden_size//2), nn.Linear(x_dim, hidden_size//2),
nn.ReLU(), nn.ReLU(),
nn.Linear(hidden_size//2, hidden_size//2),
nn.ReLU(),
nn.Linear(hidden_size//2, y_dim) nn.Linear(hidden_size//2, y_dim)
) )
self.p_logvar = nn.Sequential( self.p_logvar = nn.Sequential(
nn.Linear(x_dim, hidden_size//2), nn.Linear(x_dim, hidden_size//2),
nn.ReLU(), nn.ReLU(),
nn.Linear(hidden_size//2, hidden_size//2),
nn.ReLU(),
nn.Linear(hidden_size//2, y_dim), nn.Linear(hidden_size//2, y_dim),
nn.Tanh() nn.Tanh()
) )
@ -183,4 +187,3 @@ if __name__ == "__main__":
print(x_enc.shape) print(x_enc.shape)
print(y_enc.shape) print(y_enc.shape)
print(club.learning_loss(x_enc, y_enc)) print(club.learning_loss(x_enc, y_enc))

View File

@ -169,20 +169,26 @@ class DPI:
done = False done = False
for episode_count in range(episodes): for episode_count in range(episodes):
video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
video.init(enabled=True)
for i in range(self.args.episode_length): for i in range(self.args.episode_length):
action = self.env.action_space.sample() action = self.env.action_space.sample()
next_obs, _, done, _ = self.env.step(action) next_obs, _, done, _ = self.env.step(action)
self.data_buffer.add(obs, action, next_obs, episode_count+1, done) self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
if args.save_video:
video.record(self.env)
if done: if done:
obs = self.env.reset() obs = self.env.reset()
done=False done=False
else: else:
obs = next_obs obs = next_obs
video.save('%d.mp4' % episode_count)
print("Collected {} random episodes".format(episode_count+1)) print("Collected {} random episodes".format(episode_count+1))
#if args.save_video: #if args.save_video:
# video.record(env) # video.record(self.env)
#video.save('%d.mp4' % step) #video.save('%d.mp4' % step)
#video.close() #video.close()
@ -195,13 +201,16 @@ class DPI:
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float() next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()
# Train encoder # Train encoder
previous_information_loss = 0
for i in range(self.args.episode_length): for i in range(self.args.episode_length):
# Encode observations and next_observations # Encode observations and next_observations
self.features = self.obs_encoder(observations[i]) # (N,128) self.features = self.obs_encoder(observations[i]) # (N,128)
self.next_features = self.obs_encoder(next_observations[i]) # (N,128) self.next_features = self.obs_encoder(next_observations[i]) # (N,128)
# Calculate upper bound loss # Calculate upper bound loss
past_loss = self.upper_bound_minimization(self.features, self.next_features) past_loss = previous_information_loss + self.upper_bound_minimization(self.features, self.next_features)
previous_information_loss = past_loss
print("past_loss: ", past_loss)
def upper_bound_minimization(self, features, next_features): def upper_bound_minimization(self, features, next_features):
club_sample = CLUBSample(self.args.state_size, club_sample = CLUBSample(self.args.state_size,

View File

@ -1,3 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os import os
import torch import torch
import numpy as np import numpy as np
@ -166,6 +172,9 @@ class ReplayBuffer:
def make_env(args): def make_env(args):
# To make the ground plane transparent, set its rgba to (0, 0, 0, 0) in local_dm_control_suite/{domain_name}.xml;
# to restore the default ground plane color, set it back to (0.5, 0.5, 0.5, 1.0).
# https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-geom
env = dmc2gym.make( env = dmc2gym.make(
domain_name=args.domain_name, domain_name=args.domain_name,
task_name=args.task_name, task_name=args.task_name,

View File

@ -1,3 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import imageio import imageio
import os import os
import numpy as np import numpy as np