Compare commits
No commits in common. "11f00ad695cab2fd3208c4de8ea0f5c5588530ec" and "a351134f08bc82f387640025f4eac03929c7e160" have entirely different histories.
11f00ad695
...
a351134f08
@ -34,23 +34,13 @@ class ObservationEncoder(nn.Module):
|
|||||||
std = torch.clamp(std, min=0.0, max=1e5)
|
std = torch.clamp(std, min=0.0, max=1e5)
|
||||||
|
|
||||||
# Normal Distribution
|
# Normal Distribution
|
||||||
dist = self.get_dist(mean, std)
|
|
||||||
|
|
||||||
# Sampling via reparameterization Trick
|
|
||||||
x = self.reparameterize(mean, std)
|
x = self.reparameterize(mean, std)
|
||||||
|
return x
|
||||||
encoded_output = {"sample": x, "distribution": dist}
|
|
||||||
return encoded_output
|
|
||||||
|
|
||||||
def reparameterize(self, mu, std):
|
def reparameterize(self, mu, std):
|
||||||
eps = torch.randn_like(std)
|
eps = torch.randn_like(std)
|
||||||
return mu + eps * std
|
return mu + eps * std
|
||||||
|
|
||||||
def get_dist(self, mean, std):
|
|
||||||
distribution = torch.distributions.Normal(mean, std)
|
|
||||||
distribution = torch.distributions.independent.Independent(distribution, 1)
|
|
||||||
return distribution
|
|
||||||
|
|
||||||
|
|
||||||
class ObservationDecoder(nn.Module):
|
class ObservationDecoder(nn.Module):
|
||||||
def __init__(self, state_size, output_shape):
|
def __init__(self, state_size, output_shape):
|
||||||
@ -124,12 +114,8 @@ class TransitionModel(nn.Module):
|
|||||||
state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)
|
state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)
|
||||||
state_prior_std = F.softplus(state_prior_std)
|
state_prior_std = F.softplus(state_prior_std)
|
||||||
|
|
||||||
# Normal Distribution
|
|
||||||
state_prior_dist = self.get_dist(state_prior_mean, state_prior_std)
|
|
||||||
|
|
||||||
# Sampling via reparameterization Trick
|
|
||||||
sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std)
|
sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std)
|
||||||
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
|
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history}
|
||||||
return prior
|
return prior
|
||||||
|
|
||||||
def reparemeterize(self, mean, std):
|
def reparemeterize(self, mean, std):
|
||||||
@ -168,4 +154,15 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
|
|||||||
|
|
||||||
def forward(self, x_samples, y_samples):
|
def forward(self, x_samples, y_samples):
|
||||||
mu, logvar = self.get_mu_logvar(x_samples)
|
mu, logvar = self.get_mu_logvar(x_samples)
|
||||||
|
|
||||||
|
sample_size = x_samples.shape[0]
|
||||||
|
#random_index = torch.randint(sample_size, (sample_size,)).long()
|
||||||
|
random_index = torch.randperm(sample_size).long()
|
||||||
|
|
||||||
|
positive = - (mu - y_samples)**2 / logvar.exp()
|
||||||
|
negative = - (mu - y_samples[random_index])**2 / logvar.exp()
|
||||||
|
upper_bound = (positive.sum(dim = -1) - negative.sum(dim = -1)).mean()
|
||||||
|
return upper_bound/2.
|
||||||
|
|
||||||
|
def learning_loss(self, x_samples, y_samples):
|
||||||
return - self.loglikeli(x_samples, y_samples)
|
return - self.loglikeli(x_samples, y_samples)
|
70
DPI/train.py
70
DPI/train.py
@ -7,7 +7,6 @@ import time
|
|||||||
import json
|
import json
|
||||||
import dmc2gym
|
import dmc2gym
|
||||||
|
|
||||||
import tqdm
|
|
||||||
import wandb
|
import wandb
|
||||||
import utils
|
import utils
|
||||||
from utils import ReplayBuffer, make_env, save_image
|
from utils import ReplayBuffer, make_env, save_image
|
||||||
@ -34,10 +33,10 @@ def parse_args():
|
|||||||
parser.add_argument('--resource_files', type=str)
|
parser.add_argument('--resource_files', type=str)
|
||||||
parser.add_argument('--eval_resource_files', type=str)
|
parser.add_argument('--eval_resource_files', type=str)
|
||||||
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
|
||||||
parser.add_argument('--total_frames', default=1000, type=int) # 10000
|
parser.add_argument('--total_frames', default=10000, type=int)
|
||||||
parser.add_argument('--high_noise', action='store_true')
|
parser.add_argument('--high_noise', action='store_true')
|
||||||
# replay buffer
|
# replay buffer
|
||||||
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #50000
|
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000
|
||||||
parser.add_argument('--episode_length', default=50, type=int)
|
parser.add_argument('--episode_length', default=50, type=int)
|
||||||
# train
|
# train
|
||||||
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
|
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
|
||||||
@ -131,6 +130,10 @@ class DPI:
|
|||||||
self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
|
self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
|
||||||
self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))
|
self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))
|
||||||
|
|
||||||
|
# create video recorder
|
||||||
|
#video = VideoRecorder(video_dir if args.save_video else None, resource_files=args.resource_files)
|
||||||
|
#video.init(enabled=True)
|
||||||
|
|
||||||
# create models
|
# create models
|
||||||
self.build_models(use_saved=False, saved_model_dir=self.model_dir)
|
self.build_models(use_saved=False, saved_model_dir=self.model_dir)
|
||||||
|
|
||||||
@ -171,24 +174,28 @@ class DPI:
|
|||||||
done = False
|
done = False
|
||||||
|
|
||||||
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
|
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
|
||||||
for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
|
for episode_count in range(episodes):
|
||||||
#self.env.video.init(enabled=True)
|
self.env.video.init(enabled=True)
|
||||||
for i in range(self.args.episode_length):
|
for i in range(self.args.episode_length):
|
||||||
action = self.env.action_space.sample()
|
action = self.env.action_space.sample()
|
||||||
next_obs, _, done, _ = self.env.step(action)
|
next_obs, _, done, _ = self.env.step(action)
|
||||||
|
|
||||||
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
|
self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
|
||||||
|
|
||||||
#if args.save_video:
|
if args.save_video:
|
||||||
# self.env.video.record(self.env)
|
self.env.video.record(self.env)
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
done=False
|
done=False
|
||||||
else:
|
else:
|
||||||
obs = next_obs
|
obs = next_obs
|
||||||
#self.env.video.save('%d.mp4' % episode_count)
|
self.env.video.save('%d.mp4' % episode_count)
|
||||||
print("Collected {} random episodes".format(episode_count+1))
|
print("Collected {} random episodes".format(episode_count+1))
|
||||||
|
#if args.save_video:
|
||||||
|
# video.record(self.env)
|
||||||
|
#video.save('%d.mp4' % step)
|
||||||
|
#video.close()
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
# collect experience
|
# collect experience
|
||||||
@ -197,59 +204,26 @@ class DPI:
|
|||||||
# Group observations and next_observations by steps
|
# Group observations and next_observations by steps
|
||||||
observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()
|
observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()
|
||||||
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()
|
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()
|
||||||
actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()
|
|
||||||
|
|
||||||
# Initialize transition model states
|
|
||||||
self.transition_model.init_states(self.args.batch_size, device="cpu") # (N,128)
|
|
||||||
self.history = self.transition_model.prev_history # (N,128)
|
|
||||||
|
|
||||||
# Train encoder
|
# Train encoder
|
||||||
previous_information_loss = 0
|
previous_information_loss = 0
|
||||||
previous_encoder_loss = 0
|
|
||||||
for i in range(self.args.episode_length):
|
for i in range(self.args.episode_length):
|
||||||
# Encode observations and next_observations
|
# Encode observations and next_observations
|
||||||
self.states_dist = self.obs_encoder(observations[i])
|
self.features = self.obs_encoder(observations[i]) # (N,128)
|
||||||
self.next_states_dist = self.obs_encoder(next_observations[i])
|
self.next_features = self.obs_encoder(next_observations[i]) # (N,128)
|
||||||
|
|
||||||
# Sample states and next_states
|
|
||||||
self.states = self.states_dist["sample"] # (N,128)
|
|
||||||
self.next_states = self.next_states_dist["sample"] # (N,128)
|
|
||||||
self.actions = actions[i] # (N,6)
|
|
||||||
|
|
||||||
# Calculate upper bound loss
|
# Calculate upper bound loss
|
||||||
past_latent_loss = previous_information_loss + self._upper_bound_minimization(self.states, self.next_states)
|
past_loss = previous_information_loss + self.upper_bound_minimization(self.features, self.next_features)
|
||||||
|
previous_information_loss = past_loss
|
||||||
|
print("past_loss: ", past_loss)
|
||||||
|
|
||||||
# Calculate encoder loss
|
def upper_bound_minimization(self, features, next_features):
|
||||||
past_encoder_loss = previous_encoder_loss + self._past_encoder_loss(self.states, self.next_states,
|
|
||||||
self.states_dist, self.next_states_dist,
|
|
||||||
self.actions, self.history, i)
|
|
||||||
|
|
||||||
previous_information_loss = past_latent_loss
|
|
||||||
previous_encoder_loss = past_encoder_loss
|
|
||||||
|
|
||||||
def _upper_bound_minimization(self, states, next_states):
|
|
||||||
club_sample = CLUBSample(self.args.state_size,
|
club_sample = CLUBSample(self.args.state_size,
|
||||||
self.args.state_size,
|
self.args.state_size,
|
||||||
self.args.hidden_size)
|
self.args.hidden_size)
|
||||||
club_loss = club_sample(states, next_states)
|
club_loss = club_sample(features, next_features)
|
||||||
return club_loss
|
return club_loss
|
||||||
|
|
||||||
def _past_encoder_loss(self, states, next_states, states_dist, next_states_dist, actions, history, step):
|
|
||||||
# Imagine next state
|
|
||||||
if step == 0:
|
|
||||||
actions = torch.zeros(self.args.batch_size, self.env.action_space.shape[0]).float() # Zero action for first step
|
|
||||||
imagined_next_states = self.transition_model.imagine_step(states, actions, history)
|
|
||||||
self.history = imagined_next_states["history"]
|
|
||||||
else:
|
|
||||||
imagined_next_states = self.transition_model.imagine_step(states, actions, self.history) # (N,128)
|
|
||||||
|
|
||||||
# State Distribution
|
|
||||||
imagined_next_states_dist = imagined_next_states["distribution"]
|
|
||||||
|
|
||||||
# KL divergence loss
|
|
||||||
loss = torch.distributions.kl.kl_divergence(imagined_next_states_dist, next_states_dist["distribution"]).mean()
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
14
DPI/utils.py
14
DPI/utils.py
@ -156,16 +156,14 @@ class ReplayBuffer:
|
|||||||
obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
|
obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
|
||||||
return obs,acs,rews,terms
|
return obs,acs,rews,terms
|
||||||
|
|
||||||
def group_steps(self, buffer, variable, obs=True):
|
def group_steps(self, buffer, variable):
|
||||||
variable = getattr(buffer, variable)
|
variable = getattr(buffer, variable)
|
||||||
non_zero_indices = np.nonzero(buffer.episode_count)[0]
|
non_zero_indices = np.nonzero(buffer.episode_count)[0]
|
||||||
variable = variable[non_zero_indices]
|
variable = variable[non_zero_indices]
|
||||||
if obs:
|
|
||||||
variable = variable.reshape(self.args.episode_length, self.args.batch_size,
|
variable = variable.reshape(self.args.episode_length, self.args.batch_size,
|
||||||
self.args.frame_stack*self.args.channels,
|
self.args.frame_stack*self.args.channels,
|
||||||
self.args.image_size,self.args.image_size)
|
self.args.image_size,self.args.image_size)
|
||||||
else:
|
|
||||||
variable = variable.reshape(self.args.episode_length, self.args.batch_size,-1)
|
|
||||||
return variable
|
return variable
|
||||||
|
|
||||||
def transform_grouped_steps(self, variable):
|
def transform_grouped_steps(self, variable):
|
||||||
@ -229,7 +227,7 @@ class CorruptVideos:
|
|||||||
Check if a video file is corrupt.
|
Check if a video file is corrupt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dir_path (str): Path to the video file.
|
filepath (str): Path to the video file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the video is corrupt, False otherwise.
|
bool: True if the video is corrupt, False otherwise.
|
||||||
|
Loading…
Reference in New Issue
Block a user