From 99558ce92bc8b67b208ec9c1273c78784a52fda3 Mon Sep 17 00:00:00 2001
From: Vedant Dave
Date: Mon, 22 May 2023 14:08:02 +0200
Subject: [PATCH] Adding encoder

---
 encoder.py | 196 ++++++++++++++++++++++++--------------------------------
 1 file changed, 88 insertions(+), 108 deletions(-)

diff --git a/encoder.py b/encoder.py
index f6bc0bb..be32fd6 100644
--- a/encoder.py
+++ b/encoder.py
@@ -1,9 +1,8 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
-def tie_weights(src, trg):
+def tie_weights(src, trg):
     assert type(src) == type(trg)
     trg.weight = src.weight
     trg.bias = src.bias
@@ -11,6 +10,85 @@ def tie_weights(src, trg):
 OUT_DIM = {2: 39, 4: 35, 6: 31}
 
 
+'''
+class PixelEncoder(nn.Module):
+    """Convolutional encoder of pixels observations."""
+    def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32):
+        super().__init__()
+
+        assert len(obs_shape) == 3
+
+        self.feature_dim = feature_dim
+        self.num_layers = num_layers
+
+        self.convs = nn.ModuleList(
+            [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]
+        )
+        for i in range(num_layers - 1):
+            self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))
+
+        out_dim = OUT_DIM[num_layers]
+        self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim)
+        self.ln = nn.LayerNorm(self.feature_dim)
+
+        self.outputs = dict()
+
+    def reparameterize(self, mu, logstd):
+        std = torch.exp(logstd)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward_conv(self, obs):
+        obs = obs / 255.
+        self.outputs['obs'] = obs
+
+        conv = torch.relu(self.convs[0](obs))
+        self.outputs['conv1'] = conv
+
+        for i in range(1, self.num_layers):
+            conv = torch.relu(self.convs[i](conv))
+            self.outputs['conv%s' % (i + 1)] = conv
+
+        h = conv.view(conv.size(0), -1)
+        return h
+
+    def forward(self, obs, detach=False):
+        h = self.forward_conv(obs)
+
+        if detach:
+            h = h.detach()
+
+        h_fc = self.fc(h)
+        self.outputs['fc'] = h_fc
+
+        h_norm = self.ln(h_fc)
+        self.outputs['ln'] = h_norm
+
+        out = torch.tanh(h_norm)
+        self.outputs['tanh'] = out
+
+        return out
+
+    def copy_conv_weights_from(self, source):
+        """Tie convolutional layers"""
+        # only tie conv layers
+        for i in range(self.num_layers):
+            tie_weights(src=source.convs[i], trg=self.convs[i])
+
+    def log(self, L, step, log_freq):
+        if step % log_freq != 0:
+            return
+
+        for k, v in self.outputs.items():
+            L.log_histogram('train_encoder/%s_hist' % k, v, step)
+            if len(v.shape) > 2:
+                L.log_image('train_encoder/%s_img' % k, v[0], step)
+
+        for i in range(self.num_layers):
+            L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step)
+        L.log_param('train_encoder/fc', self.fc, step)
+        L.log_param('train_encoder/ln', self.ln, step)
+'''
 class PixelEncoder(nn.Module):
     """Convolutional encoder of pixels observations."""
@@ -64,19 +142,19 @@ class PixelEncoder(nn.Module):
 
         h_norm = self.ln(h_fc)
         self.outputs['ln'] = h_norm
-
-        h_tan = torch.tanh(h_norm)
-        mu, logstd = torch.chunk(h_tan, 2, dim=-1)
+        #out = torch.tanh(h_norm)
+
+        mu, logstd = torch.chunk(h_norm, 2, dim=-1)
+        logstd = torch.tanh(logstd)
         self.outputs['mu'] = mu
         self.outputs['logstd'] = logstd
-
-        std = torch.tanh(h_norm)
-        self.outputs['std'] = std
+        self.outputs['std'] = logstd.exp()
 
         out = self.reparameterize(mu, logstd)
-        return out, mu, logstd
-
+        self.outputs['tanh'] = out
+        return out
+
     def copy_conv_weights_from(self, source):
         """Tie convolutional layers"""
         # only tie conv layers
@@ -97,7 +175,6 @@ class PixelEncoder(nn.Module):
         L.log_param('train_encoder/fc', self.fc, step)
         L.log_param('train_encoder/ln', self.ln, step)
 
-
 class IdentityEncoder(nn.Module):
     def __init__(self, obs_shape, feature_dim, num_layers, num_filters):
         super().__init__()
@@ -115,103 +192,6 @@ class IdentityEncoder(nn.Module):
         pass
 
 
-class TransitionModel(nn.Module):
-    def __init__(self, state_size, hidden_size, action_size, history_size):
-        super().__init__()
-
-        self.state_size = state_size
-        self.hidden_size = hidden_size
-        self.action_size = action_size
-        self.history_size = history_size
-        self.act_fn = nn.ELU()
-
-        self.fc_state_action = nn.Linear(state_size + action_size, hidden_size)
-        self.history_cell = nn.GRUCell(hidden_size, history_size)
-        self.fc_state_mu = nn.Linear(history_size + hidden_size, state_size)
-        self.fc_state_sigma = nn.Linear(history_size + hidden_size, state_size)
-
-        self.batch_norm = nn.BatchNorm1d(hidden_size)
-        self.batch_norm2 = nn.BatchNorm1d(state_size)
-
-        self.min_sigma = 1e-4
-        self.max_sigma = 1e0
-
-    def init_states(self, batch_size, device):
-        self.prev_state = torch.zeros(batch_size, self.state_size).to(device)
-        self.prev_action = torch.zeros(batch_size, self.action_size).to(device)
-        self.prev_history = torch.zeros(batch_size, self.history_size).to(device)
-
-    def get_dist(self, mean, std):
-        distribution = torch.distributions.Normal(mean, std)
-        distribution = torch.distributions.independent.Independent(distribution, 1)
-        return distribution
-
-    def stack_states(self, states, dim=0):
-        s = dict(
-            mean = torch.stack([state['mean'] for state in states], dim=dim),
-            std = torch.stack([state['std'] for state in states], dim=dim),
-            sample = torch.stack([state['sample'] for state in states], dim=dim),
-            history = torch.stack([state['history'] for state in states], dim=dim),)
-        if 'distribution' in states:
-            dist = dict(distribution = [state['distribution'] for state in states])
-            s.update(dist)
-        return s
-
-    def seq_to_batch(self, state, name):
-        return dict(
-            sample = torch.reshape(state[name], (state[name].shape[0]* state[name].shape[1], *state[name].shape[2:])))
-
-    def transition_step(self, prev_state, prev_action, prev_hist, prev_not_done):
-        prev_state = prev_state.detach() * prev_not_done
-        prev_hist = prev_hist * prev_not_done
-
-        state_action_enc = self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))
-        state_action_enc = self.act_fn(self.batch_norm(state_action_enc))
-
-        current_hist = self.history_cell(state_action_enc, prev_hist)
-        state_mu = self.act_fn(self.fc_state_mu(torch.cat([state_action_enc, prev_hist], dim=-1)))
-        state_sigma = F.softplus(self.fc_state_sigma(torch.cat([state_action_enc, prev_hist], dim=-1)))
-        sample_state = state_mu + torch.randn_like(state_mu) * state_sigma
-
-        state_enc = {"mean": state_mu, "std": state_sigma, "sample": sample_state, "history": current_hist}
-        return state_enc
-
-    def observe_step(self, prev_state, prev_action, prev_history):
-        state_action_enc = self.act_fn(self.batch_norm(self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1))))
-        current_history = self.history_cell(state_action_enc, prev_history)
-        state_mu = self.act_fn(self.batch_norm2(self.fc_state_mu(torch.cat([state_action_enc, prev_history], dim=-1))))
-        state_sigma = F.softplus(self.fc_state_sigma(torch.cat([state_action_enc, prev_history], dim=-1)))
-
-        sample_state = state_mu + torch.randn_like(state_mu) * state_sigma
-        state_enc = {"mean": state_mu, "std": state_sigma, "sample": sample_state, "history": current_history}
-        return state_enc
-
-    def observe_rollout(self, rollout_states, rollout_actions,
-                        init_history, nonterms):
-        observed_rollout = []
-        for i in range(rollout_states.shape[0]):
-            rollout_states_ = rollout_states[i]
-            rollout_actions_ = rollout_actions[i]
-            init_history_ = nonterms[i] * init_history
-            state_enc = self.observe_step(rollout_states_, rollout_actions_, init_history_)
-            init_history = state_enc["history"]
-            observed_rollout.append(state_enc)
-        observed_rollout = self.stack_states(observed_rollout, dim=0)
-        return observed_rollout
-
-    def reparemeterize(self, mean, std):
-        eps = torch.randn_like(mean)
-        return mean + eps * std
-
-def club_loss(x_samples, x_mu, x_logvar, y_samples):
-    sample_size = x_samples.shape[0]
-    random_index = torch.randperm(sample_size).long()
-
-    positive = -(x_mu - y_samples)**2 / x_logvar.exp()
-    negative = - (x_mu - y_samples[random_index])**2 / x_logvar.exp()
-    upper_bound = (positive.sum(dim = -1) - negative.sum(dim = -1)).mean()
-    return upper_bound/2.
-
 
 _AVAILABLE_ENCODERS = {'pixel': PixelEncoder, 'identity': IdentityEncoder}
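
Usage note (not part of the patch): the revised forward() chunks the LayerNorm
output into mu and logstd, bounds logstd with tanh, and returns a reparameterized
sample, so the effective latent size is feature_dim // 2. Below is a minimal
sketch of how the encoder might be exercised. It assumes the active class keeps
the constructor shown in the commented-out copy, and assumes 84x84 frame-stacked
pixel observations; shapes and hyperparameters are illustrative, not values
confirmed by this patch.

    import torch
    from encoder import PixelEncoder

    obs_shape = (9, 84, 84)    # assumed: 3 stacked RGB frames
    feature_dim = 100          # fc/LayerNorm width, split into mu and logstd
    encoder = PixelEncoder(obs_shape, feature_dim, num_layers=4, num_filters=32)

    # forward_conv() divides by 255., so raw pixel values can be fed as floats
    obs = torch.randint(0, 256, (8, *obs_shape)).float()
    z = encoder(obs)           # stochastic sample, shape (8, feature_dim // 2)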
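A second sketch, continuing from the one above: tie_weights() assigns the source
layer's weight and bias Parameters directly to the target, so
copy_conv_weights_from() makes two encoders share their conv filters outright
(the actor/critic sharing pattern this helper suggests), while the detach flag
controls whether a given head backpropagates into the shared trunk. The encoder
names here are hypothetical.

    critic_encoder = PixelEncoder(obs_shape, feature_dim, num_layers=4)
    actor_encoder = PixelEncoder(obs_shape, feature_dim, num_layers=4)
    actor_encoder.copy_conv_weights_from(critic_encoder)  # conv params now shared

    # detach=True stops gradients flowing into the conv trunk from this branch
    z_actor = actor_encoder(obs, detach=True)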