105 lines
3.6 KiB
Python
105 lines
3.6 KiB
Python
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
# All rights reserved.
|
||
|
|
||
|
# This source code is licensed under the license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
import random
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
|
||
|
class DeterministicTransitionModel(nn.Module):
    """Feed-forward transition model that outputs a point estimate only.

    The network is ``Linear -> LayerNorm -> ReLU -> Linear``. It maps an
    input whose last dimension is ``encoder_feature_dim + action_shape[0]``
    to an output whose last dimension is ``encoder_feature_dim``.
    """

    def __init__(self, encoder_feature_dim, action_shape, layer_width):
        """Create the layers and announce the chosen model on stdout.

        Args:
            encoder_feature_dim: Width of the output; also one of the two
                terms that make up the input width.
            action_shape: Indexable whose first entry is added to
                ``encoder_feature_dim`` to obtain the input width.
            layer_width: Number of units in the hidden layer.
        """
        super().__init__()
        # NOTE(review): presumably the input is a latent state concatenated
        # with an action -- confirm against the caller.
        input_width = encoder_feature_dim + action_shape[0]
        self.fc = nn.Linear(input_width, layer_width)
        self.ln = nn.LayerNorm(layer_width)
        self.fc_mu = nn.Linear(layer_width, encoder_feature_dim)
        print("Deterministic transition model chosen.")

    def forward(self, x):
        """Return ``(mu, sigma)`` for ``x``; ``sigma`` is always ``None``."""
        hidden = torch.relu(self.ln(self.fc(x)))
        return self.fc_mu(hidden), None

    def sample_prediction(self, x):
        """Return the predicted mean for ``x`` (there is no noise to sample)."""
        prediction, _ = self(x)
        return prediction
|
||
|
|
||
|
|
||
|
class ProbabilisticTransitionModel(nn.Module):
    """Feed-forward transition model that outputs a mean and a bounded sigma.

    A shared trunk (``Linear -> LayerNorm -> ReLU``) feeds two linear heads:
    one produces ``mu`` and the other produces ``sigma``, which is squashed
    with a sigmoid and rescaled into ``(min_sigma, max_sigma)``.
    """

    def __init__(self, encoder_feature_dim, action_shape, layer_width, announce=True, max_sigma=1e1, min_sigma=1e-4):
        """Create the layers and store the sigma bounds.

        Args:
            encoder_feature_dim: Width of both outputs; also one of the two
                terms that make up the input width.
            action_shape: Indexable whose first entry is added to
                ``encoder_feature_dim`` to obtain the input width.
            layer_width: Number of units in the hidden layer.
            announce: When true, print the chosen model type on stdout.
            max_sigma: Upper bound of the rescaled sigma range.
            min_sigma: Lower bound of the rescaled sigma range.

        Raises:
            ValueError: If ``max_sigma`` is not greater than or equal to
                ``min_sigma``.
        """
        super().__init__()
        # Explicit check instead of ``assert`` so that it still runs under
        # ``python -O``. The ``not (a >= b)`` form also rejects NaN bounds.
        if not (max_sigma >= min_sigma):
            raise ValueError(
                "max_sigma ({}) must be >= min_sigma ({})".format(max_sigma, min_sigma)
            )

        self.fc = nn.Linear(encoder_feature_dim + action_shape[0], layer_width)
        self.ln = nn.LayerNorm(layer_width)
        self.fc_mu = nn.Linear(layer_width, encoder_feature_dim)
        self.fc_sigma = nn.Linear(layer_width, encoder_feature_dim)

        self.max_sigma = max_sigma
        self.min_sigma = min_sigma
        if announce:
            print("Probabilistic transition model chosen.")

    def forward(self, x):
        """Return ``(mu, sigma)`` for ``x``."""
        x = self.fc(x)
        x = self.ln(x)
        x = torch.relu(x)

        mu = self.fc_mu(x)
        sigma = torch.sigmoid(self.fc_sigma(x))  # range (0, 1.)
        # Affine rescale of the unit interval onto (min_sigma, max_sigma).
        sigma = self.min_sigma + (self.max_sigma - self.min_sigma) * sigma
        return mu, sigma

    def sample_prediction(self, x):
        """Draw one sample ``mu + sigma * eps`` with ``eps`` standard normal."""
        mu, sigma = self(x)
        eps = torch.randn_like(sigma)
        return mu + sigma * eps
|
||
|
|
||
|
|
||
|
class EnsembleOfProbabilisticTransitionModels:
    """A fixed-size collection of independent ``ProbabilisticTransitionModel``s.

    This is a plain Python object, not an ``nn.Module``; it exposes just the
    ``to`` and ``parameters`` helpers defined below so that the members can
    be moved between devices and optimised together.
    """

    def __init__(self, encoder_feature_dim, action_shape, layer_width, ensemble_size=5):
        """Build ``ensemble_size`` member models and announce the choice.

        Args:
            encoder_feature_dim: Forwarded to every member model.
            action_shape: Forwarded to every member model.
            layer_width: Forwarded to every member model.
            ensemble_size: Number of member models to create.

        Raises:
            ValueError: If ``ensemble_size`` is smaller than 1. An empty
                ensemble cannot be called or sampled from.
        """
        if ensemble_size < 1:
            raise ValueError(
                "ensemble_size must be at least 1, got {}".format(ensemble_size)
            )
        # Members are created silently; the ensemble prints a single message.
        self.models = [
            ProbabilisticTransitionModel(encoder_feature_dim, action_shape, layer_width, announce=False)
            for _ in range(ensemble_size)
        ]
        print("Ensemble of probabilistic transition models chosen.")

    def __call__(self, x):
        """Run every member on ``x`` and stack the results.

        Returns:
            A ``(mus, sigmas)`` pair in which each tensor has a new leading
            dimension that indexes the ensemble member.
        """
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on a member are honoured.
        mus, sigmas = zip(*(model(x) for model in self.models))
        return torch.stack(mus), torch.stack(sigmas)

    def sample_prediction(self, x):
        """Sample a prediction from one member chosen uniformly at random."""
        return random.choice(self.models).sample_prediction(x)

    def to(self, device):
        """Move every member to ``device`` and return ``self`` for chaining."""
        for model in self.models:
            model.to(device)
        return self

    def parameters(self):
        """Return a flat list of the parameters of all members, in order."""
        return [param for model in self.models for param in model.parameters()]
|
||
|
|
||
|
|
||
|
# Registry mapping a transition-model type name to the class implementing it.
# The empty string is accepted as an alias for the deterministic model.
_AVAILABLE_TRANSITION_MODELS = {
    '': DeterministicTransitionModel,
    'deterministic': DeterministicTransitionModel,
    'probabilistic': ProbabilisticTransitionModel,
    'ensemble': EnsembleOfProbabilisticTransitionModels,
}
|
||
|
|
||
|
|
||
|
def make_transition_model(transition_model_type, encoder_feature_dim, action_shape, layer_width=512):
    """Instantiate the transition model registered under ``transition_model_type``.

    Args:
        transition_model_type: A key of ``_AVAILABLE_TRANSITION_MODELS``.
        encoder_feature_dim: Forwarded to the model constructor.
        action_shape: Forwarded to the model constructor.
        layer_width: Forwarded to the model constructor.

    Returns:
        A new instance of the selected model class.

    Raises:
        ValueError: If ``transition_model_type`` is not a registered name.
    """
    # Explicit lookup instead of ``assert`` so that the check still runs
    # under ``python -O`` and the error names the valid choices.
    try:
        model_class = _AVAILABLE_TRANSITION_MODELS[transition_model_type]
    except KeyError:
        raise ValueError(
            "Unknown transition_model_type {!r}; expected one of {}".format(
                transition_model_type, sorted(_AVAILABLE_TRANSITION_MODELS)
            )
        ) from None
    return model_class(encoder_feature_dim, action_shape, layer_width)
|