Adding new background videos for each episode
This commit is contained in:
parent
43f862ee6d
commit
8464503dd8
@ -131,12 +131,16 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
|||||||
self.p_mu = nn.Sequential(
|
self.p_mu = nn.Sequential(
|
||||||
nn.Linear(x_dim, hidden_size//2),
|
nn.Linear(x_dim, hidden_size//2),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_size//2, hidden_size//2),
|
||||||
|
nn.ReLU(),
|
||||||
nn.Linear(hidden_size//2, y_dim)
|
nn.Linear(hidden_size//2, y_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.p_logvar = nn.Sequential(
|
self.p_logvar = nn.Sequential(
|
||||||
nn.Linear(x_dim, hidden_size//2),
|
nn.Linear(x_dim, hidden_size//2),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_size//2, hidden_size//2),
|
||||||
|
nn.ReLU(),
|
||||||
nn.Linear(hidden_size//2, y_dim),
|
nn.Linear(hidden_size//2, y_dim),
|
||||||
nn.Tanh()
|
nn.Tanh()
|
||||||
)
|
)
|
||||||
@ -183,4 +187,3 @@ if __name__ == "__main__":
|
|||||||
print(x_enc.shape)
|
print(x_enc.shape)
|
||||||
print(y_enc.shape)
|
print(y_enc.shape)
|
||||||
print(club.learning_loss(x_enc, y_enc))
|
print(club.learning_loss(x_enc, y_enc))
|
||||||
|
|
||||||
|
13
DPI/train.py
13
DPI/train.py
@ -169,20 +169,26 @@ class DPI:
|
|||||||
done = False
|
done = False
|
||||||
|
|
||||||
for episode_count in range(episodes):
|
for episode_count in range(episodes):
|
||||||
|
video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
|
||||||
|
video.init(enabled=True)
|
||||||
for i in range(self.args.episode_length):
|
for i in range(self.args.episode_length):
|
||||||
action = self.env.action_space.sample()
|
action = self.env.action_space.sample()
|
||||||
next_obs, _, done, _ = self.env.step(action)
|
next_obs, _, done, _ = self.env.step(action)
|
||||||
|
|
||||||
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
|
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
|
||||||
|
|
||||||
|
if args.save_video:
|
||||||
|
video.record(self.env)
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
done=False
|
done=False
|
||||||
else:
|
else:
|
||||||
obs = next_obs
|
obs = next_obs
|
||||||
|
video.save('%d.mp4' % episode_count)
|
||||||
print("Collected {} random episodes".format(episode_count+1))
|
print("Collected {} random episodes".format(episode_count+1))
|
||||||
#if args.save_video:
|
#if args.save_video:
|
||||||
# video.record(env)
|
# video.record(self.env)
|
||||||
#video.save('%d.mp4' % step)
|
#video.save('%d.mp4' % step)
|
||||||
#video.close()
|
#video.close()
|
||||||
|
|
||||||
@ -195,13 +201,16 @@ class DPI:
|
|||||||
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()
|
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()
|
||||||
|
|
||||||
# Train encoder
|
# Train encoder
|
||||||
|
previous_information_loss = 0
|
||||||
for i in range(self.args.episode_length):
|
for i in range(self.args.episode_length):
|
||||||
# Encode observations and next_observations
|
# Encode observations and next_observations
|
||||||
self.features = self.obs_encoder(observations[i]) # (N,128)
|
self.features = self.obs_encoder(observations[i]) # (N,128)
|
||||||
self.next_features = self.obs_encoder(next_observations[i]) # (N,128)
|
self.next_features = self.obs_encoder(next_observations[i]) # (N,128)
|
||||||
|
|
||||||
# Calculate upper bound loss
|
# Calculate upper bound loss
|
||||||
past_loss = self.upper_bound_minimization(self.features, self.next_features)
|
past_loss = previous_information_loss + self.upper_bound_minimization(self.features, self.next_features)
|
||||||
|
previous_information_loss = past_loss
|
||||||
|
print("past_loss: ", past_loss)
|
||||||
|
|
||||||
def upper_bound_minimization(self, features, next_features):
|
def upper_bound_minimization(self, features, next_features):
|
||||||
club_sample = CLUBSample(self.args.state_size,
|
club_sample = CLUBSample(self.args.state_size,
|
||||||
|
@ -1,3 +1,9 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -166,6 +172,9 @@ class ReplayBuffer:
|
|||||||
|
|
||||||
|
|
||||||
def make_env(args):
|
def make_env(args):
|
||||||
|
# For making ground plane transparent, change rgba to (0, 0, 0, 0) in local_dm_control_suite/{domain_name}.xml,
|
||||||
|
# else change to (0.5, 0.5, 0.5, 1.0) for default ground plane color
|
||||||
|
# https://mujoco.readthedocs.io/en/stable/XMLreference.html#body-geom
|
||||||
env = dmc2gym.make(
|
env = dmc2gym.make(
|
||||||
domain_name=args.domain_name,
|
domain_name=args.domain_name,
|
||||||
task_name=args.task_name,
|
task_name=args.task_name,
|
||||||
|
@ -1,3 +1,9 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import imageio
|
import imageio
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
Loading…
Reference in New Issue
Block a user