Compare commits

..

3 Commits

3 changed files with 76 additions and 45 deletions

View File

@ -34,12 +34,22 @@ class ObservationEncoder(nn.Module):
std = torch.clamp(std, min=0.0, max=1e5) std = torch.clamp(std, min=0.0, max=1e5)
# Normal Distribution # Normal Distribution
dist = self.get_dist(mean, std)
# Sampling via reparameterization Trick
x = self.reparameterize(mean, std) x = self.reparameterize(mean, std)
return x
encoded_output = {"sample": x, "distribution": dist}
return encoded_output
def reparameterize(self, mu, std): def reparameterize(self, mu, std):
eps = torch.randn_like(std) eps = torch.randn_like(std)
return mu + eps * std return mu + eps * std
def get_dist(self, mean, std):
distribution = torch.distributions.Normal(mean, std)
distribution = torch.distributions.independent.Independent(distribution, 1)
return distribution
class ObservationDecoder(nn.Module): class ObservationDecoder(nn.Module):
@ -114,8 +124,12 @@ class TransitionModel(nn.Module):
state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1) state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)
state_prior_std = F.softplus(state_prior_std) state_prior_std = F.softplus(state_prior_std)
# Normal Distribution
state_prior_dist = self.get_dist(state_prior_mean, state_prior_std)
# Sampling via reparameterization Trick
sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std) sample_state_prior = self.reparemeterize(state_prior_mean, state_prior_std)
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history} prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
return prior return prior
def reparemeterize(self, mean, std): def reparemeterize(self, mean, std):
@ -154,15 +168,4 @@ class CLUBSample(nn.Module): # Sampled version of the CLUB estimator
def forward(self, x_samples, y_samples): def forward(self, x_samples, y_samples):
mu, logvar = self.get_mu_logvar(x_samples) mu, logvar = self.get_mu_logvar(x_samples)
sample_size = x_samples.shape[0]
#random_index = torch.randint(sample_size, (sample_size,)).long()
random_index = torch.randperm(sample_size).long()
positive = - (mu - y_samples)**2 / logvar.exp()
negative = - (mu - y_samples[random_index])**2 / logvar.exp()
upper_bound = (positive.sum(dim = -1) - negative.sum(dim = -1)).mean()
return upper_bound/2.
def learning_loss(self, x_samples, y_samples):
return - self.loglikeli(x_samples, y_samples) return - self.loglikeli(x_samples, y_samples)

View File

@ -7,6 +7,7 @@ import time
import json import json
import dmc2gym import dmc2gym
import tqdm
import wandb import wandb
import utils import utils
from utils import ReplayBuffer, make_env, save_image from utils import ReplayBuffer, make_env, save_image
@ -33,10 +34,10 @@ def parse_args():
parser.add_argument('--resource_files', type=str) parser.add_argument('--resource_files', type=str)
parser.add_argument('--eval_resource_files', type=str) parser.add_argument('--eval_resource_files', type=str)
parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none']) parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
parser.add_argument('--total_frames', default=10000, type=int) parser.add_argument('--total_frames', default=1000, type=int) # 10000
parser.add_argument('--high_noise', action='store_true') parser.add_argument('--high_noise', action='store_true')
# replay buffer # replay buffer
parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #100000 parser.add_argument('--replay_buffer_capacity', default=50000, type=int) #50000
parser.add_argument('--episode_length', default=50, type=int) parser.add_argument('--episode_length', default=50, type=int)
# train # train
parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc']) parser.add_argument('--agent', default='dpi', type=str, choices=['baseline', 'bisim', 'deepmdp', 'db', 'dpi', 'rpc'])
@ -130,10 +131,6 @@ class DPI:
self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model')) self.model_dir = utils.make_dir(os.path.join(self.args.work_dir, 'model'))
self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer')) self.buffer_dir = utils.make_dir(os.path.join(self.args.work_dir, 'buffer'))
# create video recorder
#video = VideoRecorder(video_dir if args.save_video else None, resource_files=args.resource_files)
#video.init(enabled=True)
# create models # create models
self.build_models(use_saved=False, saved_model_dir=self.model_dir) self.build_models(use_saved=False, saved_model_dir=self.model_dir)
@ -174,28 +171,24 @@ class DPI:
done = False done = False
#video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files) #video = VideoRecorder(self.video_dir if args.save_video else None, resource_files=args.resource_files)
for episode_count in range(episodes): for episode_count in tqdm.tqdm(range(episodes), desc='Collecting episodes'):
self.env.video.init(enabled=True) #self.env.video.init(enabled=True)
for i in range(self.args.episode_length): for i in range(self.args.episode_length):
action = self.env.action_space.sample() action = self.env.action_space.sample()
next_obs, _, done, _ = self.env.step(action) next_obs, _, done, _ = self.env.step(action)
self.data_buffer.add(obs, action, next_obs, episode_count+1, done) self.data_buffer.add(obs, action, next_obs, episode_count+1, done)
if args.save_video: #if args.save_video:
self.env.video.record(self.env) # self.env.video.record(self.env)
if done: if done:
obs = self.env.reset() obs = self.env.reset()
done=False done=False
else: else:
obs = next_obs obs = next_obs
self.env.video.save('%d.mp4' % episode_count) #self.env.video.save('%d.mp4' % episode_count)
print("Collected {} random episodes".format(episode_count+1)) print("Collected {} random episodes".format(episode_count+1))
#if args.save_video:
# video.record(self.env)
#video.save('%d.mp4' % step)
#video.close()
def train(self): def train(self):
# collect experience # collect experience
@ -204,26 +197,59 @@ class DPI:
# Group observations and next_observations by steps # Group observations and next_observations by steps
observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float() observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"observations")).float()
next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float() next_observations = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"next_observations")).float()
actions = torch.Tensor(self.data_buffer.group_steps(self.data_buffer,"actions",obs=False)).float()
# Initialize transition model states
self.transition_model.init_states(self.args.batch_size, device="cpu") # (N,128)
self.history = self.transition_model.prev_history # (N,128)
# Train encoder # Train encoder
previous_information_loss = 0 previous_information_loss = 0
previous_encoder_loss = 0
for i in range(self.args.episode_length): for i in range(self.args.episode_length):
# Encode observations and next_observations # Encode observations and next_observations
self.features = self.obs_encoder(observations[i]) # (N,128) self.states_dist = self.obs_encoder(observations[i])
self.next_features = self.obs_encoder(next_observations[i]) # (N,128) self.next_states_dist = self.obs_encoder(next_observations[i])
# Sample states and next_states
self.states = self.states_dist["sample"] # (N,128)
self.next_states = self.next_states_dist["sample"] # (N,128)
self.actions = actions[i] # (N,6)
# Calculate upper bound loss # Calculate upper bound loss
past_loss = previous_information_loss + self.upper_bound_minimization(self.features, self.next_features) past_latent_loss = previous_information_loss + self._upper_bound_minimization(self.states, self.next_states)
previous_information_loss = past_loss
print("past_loss: ", past_loss) # Calculate encoder loss
past_encoder_loss = previous_encoder_loss + self._past_encoder_loss(self.states, self.next_states,
def upper_bound_minimization(self, features, next_features): self.states_dist, self.next_states_dist,
club_sample = CLUBSample(self.args.state_size, self.actions, self.history, i)
previous_information_loss = past_latent_loss
previous_encoder_loss = past_encoder_loss
def _upper_bound_minimization(self, states, next_states):
club_sample = CLUBSample(self.args.state_size,
self.args.state_size, self.args.state_size,
self.args.hidden_size) self.args.hidden_size)
club_loss = club_sample(features, next_features) club_loss = club_sample(states, next_states)
return club_loss return club_loss
def _past_encoder_loss(self, states, next_states, states_dist, next_states_dist, actions, history, step):
# Imagine next state
if step == 0:
actions = torch.zeros(self.args.batch_size, self.env.action_space.shape[0]).float() # Zero action for first step
imagined_next_states = self.transition_model.imagine_step(states, actions, history)
self.history = imagined_next_states["history"]
else:
imagined_next_states = self.transition_model.imagine_step(states, actions, self.history) # (N,128)
# State Distribution
imagined_next_states_dist = imagined_next_states["distribution"]
# KL divergence loss
loss = torch.distributions.kl.kl_divergence(imagined_next_states_dist, next_states_dist["distribution"]).mean()
return loss
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()

View File

@ -156,14 +156,16 @@ class ReplayBuffer:
obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l) obs,acs,rews,terms= self._retrieve_batch(np.asarray([self._sample_idx(l) for _ in range(n)]), n, l)
return obs,acs,rews,terms return obs,acs,rews,terms
def group_steps(self, buffer, variable): def group_steps(self, buffer, variable, obs=True):
variable = getattr(buffer, variable) variable = getattr(buffer, variable)
non_zero_indices = np.nonzero(buffer.episode_count)[0] non_zero_indices = np.nonzero(buffer.episode_count)[0]
variable = variable[non_zero_indices] variable = variable[non_zero_indices]
if obs:
variable = variable.reshape(self.args.episode_length, self.args.batch_size, variable = variable.reshape(self.args.episode_length, self.args.batch_size,
self.args.frame_stack*self.args.channels, self.args.frame_stack*self.args.channels,
self.args.image_size,self.args.image_size) self.args.image_size,self.args.image_size)
else:
variable = variable.reshape(self.args.episode_length, self.args.batch_size,-1)
return variable return variable
def transform_grouped_steps(self, variable): def transform_grouped_steps(self, variable):
@ -227,7 +229,7 @@ class CorruptVideos:
Check if a video file is corrupt. Check if a video file is corrupt.
Args: Args:
filepath (str): Path to the video file. dir_path (str): Path to the video file.
Returns: Returns:
bool: True if the video is corrupt, False otherwise. bool: True if the video is corrupt, False otherwise.