import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal


class ObservationEncoder(nn.Module):
    """Convolutional encoder mapping image observations to a diagonal Gaussian.

    Args:
        obs_shape: ``(channels, height, width)`` of a single observation.
        state_size: Dimensionality of the encoded state.
        num_layers: Number of ``Conv2d`` + ``LeakyReLU`` stages.
        num_filters: Output channels of the first stage; doubled at each
            following stage.
        stride: Accepted for interface compatibility but not used; every
            convolution uses stride 2.
    """

    def __init__(self, obs_shape, state_size, num_layers=4, num_filters=32, stride=None):
        super().__init__()
        assert len(obs_shape) == 3, "obs_shape must be (channels, height, width)"
        self.state_size = state_size

        layers = []
        in_channels = obs_shape[0]
        height, width = obs_shape[1], obs_shape[2]
        for i in range(num_layers):
            out_channels = num_filters * (2 ** i)
            layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                    kernel_size=4, stride=2))
            layers.append(nn.LeakyReLU())
            in_channels = out_channels
            # Spatial size after an unpadded kernel-4 / stride-2 convolution.
            height = (height - 4) // 2 + 1
            width = (width - 4) // 2 + 1
        if height < 1 or width < 1:
            raise ValueError(
                f"obs_shape {tuple(obs_shape)} is too small for {num_layers} conv layers")
        self.convs = nn.Sequential(*layers)

        # Flattened feature size is derived from the input shape instead of
        # being fixed to 256 * 3 * 3 (which only holds for 84x84 inputs with
        # the default num_layers / num_filters).
        self.fc = nn.Linear(in_channels * height * width, 2 * state_size)

    def forward(self, x):
        """Encode a batch of observations.

        Returns:
            dict with ``"sample"`` (reparameterized sample) and
            ``"distribution"`` (``Independent(Normal)`` over the last dim).
        """
        x = self.convs(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        # Mean and standard deviation
        mean, std = torch.chunk(x, 2, dim=-1)
        std = F.softplus(std)
        std = torch.clamp(std, min=0.0, max=1e5)

        # Normal Distribution
        dist = self.get_dist(mean, std)

        # Sampling via reparameterization Trick
        x = self.reparameterize(mean, std)

        encoded_output = {"sample": x, "distribution": dist}
        return encoded_output

    def reparameterize(self, mu, std):
        """Return ``mu + eps * std`` with ``eps`` drawn from a standard normal."""
        eps = torch.randn_like(std)
        return mu + eps * std

    def get_dist(self, mean, std):
        """Build a diagonal Gaussian whose last dimension is the event dim."""
        distribution = torch.distributions.Normal(mean, std)
        distribution = torch.distributions.independent.Independent(distribution, 1)
        return distribution


class ObservationDecoder(nn.Module):
    """Transposed-convolution decoder from state features to an image distribution.

    Args:
        state_size: Size of the last dimension of the input features.
        output_shape: ``(channels, height, width)`` of the reconstruction.
            ``output_shape[1]`` must be 84 or 64.

    Raises:
        ValueError: If ``output_shape[1]`` is not one of the supported sizes.
    """

    def __init__(self, state_size, output_shape):
        super().__init__()
        self.state_size = state_size
        self.output_shape = output_shape
        self.input_size = 256 * 3 * 3
        self.in_channels = [self.input_size, 256, 128, 64]
        # The final channel count follows the requested output shape
        # (3 for RGB images, as before).
        self.out_channels = [256, 128, 64, output_shape[0]]

        # Kernel sizes / output paddings that grow a 1x1 map to the target size.
        if output_shape[1] == 84:
            self.kernels = [5, 7, 5, 6]
            self.output_padding = [1, 1, 1, 0]
        elif output_shape[1] == 64:
            self.kernels = [5, 5, 6, 6]
            self.output_padding = [0, 0, 0, 0]
        else:
            raise ValueError(
                f"Unsupported output size {output_shape[1]}; expected 84 or 64")

        self.dense = nn.Linear(state_size, self.input_size)

        layers = []
        last = len(self.kernels) - 1
        for i, kernel in enumerate(self.kernels):
            layers.append(nn.ConvTranspose2d(in_channels=self.in_channels[i],
                                             out_channels=self.out_channels[i],
                                             kernel_size=kernel, stride=2,
                                             output_padding=self.output_padding[i]))
            # No activation after the last layer: it produces the pixel means.
            if i != last:
                layers.append(nn.ReLU())
        self.convtranspose = nn.Sequential(*layers)

    def forward(self, features):
        """Decode features of shape ``(*batch, state_size)``.

        Returns:
            ``Independent(Normal(mean, 1))`` whose event shape is
            ``output_shape`` and whose batch shape is ``features.shape[:-1]``.
        """
        out_batch_shape = features.shape[:-1]
        out = self.dense(features)
        out = torch.reshape(out, [-1, self.input_size, 1, 1])
        out = self.convtranspose(out)
        mean = torch.reshape(out, (*out_batch_shape, *self.output_shape))

        out_dist = torch.distributions.independent.Independent(
            torch.distributions.Normal(mean, 1), len(self.output_shape))
        return out_dist


class Actor(nn.Module):
    """MLP policy producing tanh-squashed Gaussian action samples.

    Args:
        state_size: Size of the input features.
        hidden_size: Width of the hidden layers.
        action_size: Dimensionality of the action.
        num_layers: Total number of linear layers.
    """

    def __init__(self, state_size, hidden_size, action_size, num_layers=5):
        super().__init__()
        self.state_size = state_size
        self.hidden_size = hidden_size
        self.action_size = action_size
        self.num_layers = num_layers
        self._min_std = torch.Tensor([1e-4])[0]
        self._init_std = torch.Tensor([5])[0]
        self._mean_scale = torch.Tensor([5])[0]

        layers = []
        last = self.num_layers - 1
        for i in range(self.num_layers):
            input_channels = state_size if i == 0 else self.hidden_size
            output_channels = self.hidden_size if i != last else 2 * action_size
            layers.append(nn.Linear(input_channels, output_channels))
            # The output layer stays linear: a trailing ReLU would force the
            # raw mean / std to be non-negative, so the action mean could
            # never be negative.
            if i != last:
                layers.append(nn.ReLU())
        self.action_model = nn.Sequential(*layers)

    def get_dist(self, mean, std):
        """Build a tanh-transformed diagonal Gaussian over actions."""
        distribution = torch.distributions.Normal(mean, std)
        distribution = torch.distributions.transformed_distribution.TransformedDistribution(
            distribution, TanhBijector())
        distribution = torch.distributions.independent.Independent(distribution, 1)
        return distribution

    def forward(self, features):
        """Return a reparameterized action sample for ``features``."""
        out = self.action_model(features)
        mean, std = torch.chunk(out, 2, dim=-1)

        # Inverse softplus of the initial std, so a zero raw output gives
        # softplus(raw_init_std) == _init_std.
        raw_init_std = torch.log(torch.exp(self._init_std) - 1)
        action_mean = self._mean_scale * torch.tanh(mean / self._mean_scale)
        action_std = F.softplus(std + raw_init_std) + self._min_std

        dist = self.get_dist(action_mean, action_std)
        sample = dist.rsample()
        return sample


class ValueModel(nn.Module):
    """MLP producing a unit-variance Gaussian over a scalar value.

    Args:
        state_size: Size of the input features.
        hidden_size: Width of the hidden layers.
        num_layers: Total number of linear layers.
    """

    def __init__(self, state_size, hidden_size, num_layers=4):
        super().__init__()
        self.state_size = state_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        layers = []
        last = self.num_layers - 1
        for i in range(self.num_layers):
            input_channels = state_size if i == 0 else self.hidden_size
            output_channels = self.hidden_size if i != last else 1
            layers.append(nn.Linear(input_channels, output_channels))
            # The output layer stays linear so the predicted value is not
            # clipped to be non-negative.
            if i != last:
                layers.append(nn.ReLU())
        self.value_model = nn.Sequential(*layers)

    def forward(self, state):
        """Return ``Independent(Normal(value, 1), 1)`` for ``state``."""
        value = self.value_model(state)
        value_dist = torch.distributions.independent.Independent(
            torch.distributions.Normal(value, 1), 1)
        return value_dist


class RewardModel(nn.Module):
    """Three-layer MLP predicting a scalar reward from a state."""

    def __init__(self, state_size, hidden_size):
        super().__init__()
        self.reward_model = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state):
        """Return the predicted reward with dimension 1 squeezed.

        For a ``(batch, state_size)`` input this yields shape ``(batch,)``.
        NOTE(review): ``squeeze(dim=1)`` only removes the trailing singleton
        for 2-D inputs -- confirm callers never pass higher-rank states.
        """
        reward = self.reward_model(state).squeeze(dim=1)
        return reward


class TransitionModel(nn.Module):
    """Recurrent latent transition model built around a GRU cell.

    Args:
        state_size: Dimensionality of the stochastic state.
        hidden_size: Width of the state-action embedding.
        action_size: Dimensionality of the action.
        history_size: Size of the GRU hidden state.
    """

    def __init__(self, state_size, hidden_size, action_size, history_size):
        super().__init__()
        self.state_size = state_size
        self.hidden_size = hidden_size
        self.action_size = action_size
        self.history_size = history_size
        self.act_fn = nn.ReLU()

        self.fc_state_action = nn.Linear(state_size + action_size, hidden_size)
        self.history_cell = nn.GRUCell(hidden_size + history_size, history_size)
        self.fc_state_prior = nn.Linear(history_size + state_size + action_size, 2 * state_size)
        self.fc_state_posterior = nn.Linear(history_size + state_size + action_size, 2 * state_size)

    def init_states(self, batch_size, device):
        """Store zero-initialised previous state, action and history on ``device``."""
        self.prev_state = torch.zeros(batch_size, self.state_size).to(device)
        self.prev_action = torch.zeros(batch_size, self.action_size).to(device)
        self.prev_history = torch.zeros(batch_size, self.history_size).to(device)

    def get_dist(self, mean, std):
        """Build a diagonal Gaussian whose last dimension is the event dim."""
        distribution = torch.distributions.Normal(mean, std)
        distribution = torch.distributions.independent.Independent(distribution, 1)
        return distribution

    def imagine_step(self, prev_state, prev_action, prev_history):
        """Predict the prior over the next state without an observation.

        Returns:
            dict with keys ``"mean"``, ``"std"``, ``"sample"``, ``"history"``
            and ``"distribution"``.
        """
        state_action = self.act_fn(
            self.fc_state_action(torch.cat([prev_state, prev_action], dim=-1)))
        # The incoming history is detached, so gradients do not flow back
        # through earlier recurrent steps.
        prev_hist = prev_history.detach()
        history = self.history_cell(torch.cat([state_action, prev_hist], dim=-1), prev_hist)

        state_prior = self.fc_state_prior(torch.cat([history, prev_state, prev_action], dim=-1))
        state_prior_mean, state_prior_std = torch.chunk(state_prior, 2, dim=-1)
        state_prior_std = F.softplus(state_prior_std)

        # Normal Distribution
        state_prior_dist = self.get_dist(state_prior_mean, state_prior_std)

        # Sampling via reparameterization Trick
        sample_state_prior = self.reparameterize(state_prior_mean, state_prior_std)

        prior = {"mean": state_prior_mean, "std": state_prior_std,
                 "sample": sample_state_prior, "history": history,
                 "distribution": state_prior_dist}
        return prior

    def stack_states(self, states, dim=0):
        """Stack the tensor entries of ``states`` along ``dim``.

        The ``"distribution"`` entries cannot be stacked and are returned as
        a plain list in the same order.
        """
        s = dict(
            mean=torch.stack([state['mean'] for state in states], dim=dim),
            std=torch.stack([state['std'] for state in states], dim=dim),
            sample=torch.stack([state['sample'] for state in states], dim=dim),
            history=torch.stack([state['history'] for state in states], dim=dim),
        )
        dist = dict(distribution=[state['distribution'] for state in states])
        s.update(dist)
        return s

    def imagine_rollout(self, state, action, history, horizon):
        """Roll the prior forward ``horizon`` steps.

        The same ``action`` is fed at every step; only the state sample and
        the history are carried forward. Returns the per-step priors stacked
        along dimension 0 (see ``stack_states``).
        """
        imagined_priors = []
        for _ in range(horizon):
            prior = self.imagine_step(state, action, history)
            state = prior["sample"]
            history = prior["history"]
            imagined_priors.append(prior)

        imagined_priors = self.stack_states(imagined_priors, dim=0)
        return imagined_priors

    def reparameterize(self, mean, std):
        """Return ``mean + eps * std`` with ``eps`` drawn from a standard normal."""
        eps = torch.randn_like(std)
        return mean + eps * std

    # Backward-compatible alias for the original (misspelled) method name.
    reparemeterize = reparameterize


class TanhBijector(torch.distributions.Transform):
    """``tanh`` transform mapping the reals onto the interval (-1, 1)."""

    def __init__(self):
        super().__init__()
        self.bijective = True
        self.domain = torch.distributions.constraints.real
        self.codomain = torch.distributions.constraints.interval(-1.0, 1.0)

    @property
    def sign(self):
        return 1.

    def _call(self, x):
        return torch.tanh(x)

    def atanh(self, x):
        """Inverse hyperbolic tangent computed as ``0.5 * log((1 + x) / (1 - x))``."""
        return 0.5 * torch.log((1 + x) / (1 - x))

    def _inverse(self, y: torch.Tensor):
        # Pull values in [-1, 1] slightly inside the open interval so atanh
        # stays finite at the boundaries; values outside are left untouched.
        y = torch.where(
            (torch.abs(y) <= 1.),
            torch.clamp(y, -0.99999997, 0.99999997),
            y)
        y = self.atanh(y)
        return y

    def log_abs_det_jacobian(self, x, y):
        """Return ``log|d tanh(x) / dx|`` as ``2 * (log 2 - x - softplus(-2x))``.

        ``log 2`` is a Python scalar, so the result keeps the device, dtype
        and shape of ``x`` (a freshly built CPU tensor would not).
        """
        return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))


class ProjectionHead(nn.Module):
    """MLP projecting a concatenated (state, action) pair to ``hidden_size``."""

    def __init__(self, state_size, action_size, hidden_size):
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size

        self.projection_model = nn.Sequential(
            nn.Linear(state_size + action_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )

    def forward(self, state, action):
        """Concatenate ``state`` and ``action`` on the last dim and project."""
        x = torch.cat([state, action], dim=-1)
        x = self.projection_model(x)
        return x


class ContrastiveHead(nn.Module):
    """Bilinear similarity head producing pairwise logits ``z_a W z_pos^T``."""

    def __init__(self, hidden_size, temperature=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.temperature = temperature
        self.W = nn.Parameter(torch.rand(self.hidden_size, self.hidden_size))

    def forward(self, z_a, z_pos):
        """Return ``(B, B)`` logits between anchors ``z_a`` and positives ``z_pos``."""
        Wz = torch.matmul(self.W, z_pos.T)  # (z_dim,B)
        logits = torch.matmul(z_a, Wz)  # (B,B)
        # Subtract the row-wise maximum for numerical stability.
        logits = logits - torch.max(logits, 1)[0][:, None]
        logits = logits * self.temperature
        return logits


class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator
    """Sampled CLUB estimator over dictionaries of state distributions.

    Each ``*_states`` argument is a dict with a ``"distribution"`` entry
    exposing ``sample()``, ``mean`` and ``variance``.
    """

    def __init__(self, last_states, current_states, negative_current_states,
                 predicted_current_states):
        super().__init__()
        self.last_states = last_states
        self.current_states = current_states
        self.negative_current_states = negative_current_states
        self.predicted_current_states = predicted_current_states

    def get_mu_var_samples(self, state_dict):
        """Return ``(mean, variance, sample)`` of ``state_dict["distribution"]``."""
        dist = state_dict["distribution"]
        sample = dist.sample()  # Use state_dict["sample"] if you want to use the same sample for all the losses
        mu = dist.mean
        var = dist.variance
        return mu, var, sample

    def loglikeli(self):
        """Unnormalised log-likelihood of the predicted sample under the current states."""
        _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
        mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
        logvar_curr = torch.log(var_curr)
        return (-(mu_curr - pred_sample) ** 2 / var_curr - logvar_curr).sum(dim=1).mean(dim=0)

    def forward(self):
        """Return half the difference between the positive and negative terms."""
        _, _, pred_sample = self.get_mu_var_samples(self.predicted_current_states)
        mu_curr, var_curr, _ = self.get_mu_var_samples(self.current_states)
        mu_neg, var_neg, _ = self.get_mu_var_samples(self.negative_current_states)

        pos = (-(mu_curr - pred_sample) ** 2 / var_curr).sum(dim=1).mean(dim=0)
        neg = (-(mu_neg - pred_sample) ** 2 / var_neg).sum(dim=1).mean(dim=0)
        upper_bound = pos - neg
        return upper_bound / 2

    def learning_loss(self):
        """Negative of ``loglikeli``, suitable for minimisation."""
        return - self.loglikeli()


if __name__ == "__main__":
    pass