Adding new background videos for each episode
This commit is contained in:
parent
43f862ee6d
commit
8464503dd8
@ -131,12 +131,16 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
||||
self.p_mu = nn.Sequential(
|
||||
nn.Linear(x_dim, hidden_size//2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size//2, hidden_size//2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size//2, y_dim)
|
||||
)
|
||||
|
||||
self.p_logvar = nn.Sequential(
|
||||
nn.Linear(x_dim, hidden_size//2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size//2, hidden_size//2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size//2, y_dim),
|
||||
nn.Tanh()
|
||||
)
|
||||
@ -182,5 +186,4 @@ if __name__ == "__main__":
|
||||
y_enc = encoder(y)
|
||||
print(x_enc.shape)
|
||||
print(y_enc.shape)
|
||||
print(club.learning_loss(x_enc, y_enc))
|
||||
|
||||
print(club.learning_loss(x_enc, y_enc))
|
15
DPI/train.py
15
DPI/train.py
@ -169,20 +169,26 @@ class DPI:
|
||||
done = False
|
||||
|
||||
for episode_count in range(episodes):
|
||||
video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
|
||||
video.init(enabled=True)
|
||||
for i in range(self.args.episode_length):
|
||||
action = self.env.action_space.sample()
|
||||
next_obs, _, done, _ = self.env.step(action)
|
||||
|
||||
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
|
||||
|
||||
|
||||
if args.save_video:
|
||||
video.record(self.env)
|
||||
|
||||
if done:
|
||||
obs = self.env.reset()
|
||||
done=False
|
||||
else:
|
||||
obs = next_obs
|
||||
video.save('%d.mp4' % episode_count)
|
||||
print("Collected {} random episodes".format(episode_count+1))
|
||||
#if args.save_video:
|
||||
# video.record(env)
|
||||
# video.record(self.env)
|
||||
#video.save('%d.mp4' % step)
|
||||
#video.close()
|
||||
|
||||
@ -195,13 +201,16 @@ class DPI:
|
||||
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()
|
||||
|
||||
# Train encoder
|
||||
previous_information_loss = 0
|
||||
for i in range(self.args.episode_length):
|
||||
# Encode observations and next_observations
|
||||
self.features = self.obs_encoder(observations[i]) # (N,128)
|
||||
self.next_features = self.obs_encoder(next_observations[i]) # (N,128)
|
||||
|
||||
# Calculate upper bound loss
|
||||
past_loss = self.upper_bound_minimization(self.features, self.next_features)
|
||||
past_loss = previous_information_loss + self.upper_bound_minimization(self.features, self.next_features)
|
||||
previous_information_loss = past_loss
|
||||
print("past_loss: ", past_loss)
|
||||
|
||||
def upper_bound_minimization(self, features, next_features):
|
||||
club_sample = CLUBSample(self.args.state_size,
|
||||
|
@ -1,3 +1,9 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
@ -166,6 +172,9 @@ class ReplayBuffer:
|
||||
|
||||
|
||||
def make_env(args):
|
||||
# For making ground plane transparent, change rgba to (0, 0, 0, 0) in local_dm_control_suite/{domain_name}.xml,
|
||||
# else change to (0.5, 0.5, 0.5, 1.0) for default ground plane color
|
||||
# https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-geom
|
||||
env = dmc2gym.make(
|
||||
domain_name=args.domain_name,
|
||||
task_name=args.task_name,
|
||||
|
@ -1,3 +1,9 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import imageio
|
||||
import os
|
||||
import numpy as np
|
||||
@ -51,4 +57,4 @@ class VideoRecorder(object):
|
||||
def save(self, file_name):
|
||||
if self.enabled:
|
||||
path = os.path.join(self.dir_name, file_name)
|
||||
imageio.mimsave(path, self.frames, fps=self.fps)
|
||||
imageio.mimsave(path, self.frames, fps=self.fps)
|
Loading…
Reference in New Issue
Block a user