DBC/transition_model.py

105 lines
3.6 KiB
Python
Raw Normal View History

2020-10-12 22:39:25 +00:00
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import torch
import torch.nn as nn
class DeterministicTransitionModel(nn.Module):
    """Transition model that produces a single point prediction.

    The network is Linear -> LayerNorm -> ReLU -> Linear. Its input width is
    ``encoder_feature_dim + action_shape[0]`` and its output width is
    ``encoder_feature_dim``.
    """

    def __init__(self, encoder_feature_dim, action_shape, layer_width):
        """Build the layers and announce the chosen model on stdout.

        Args:
            encoder_feature_dim: Width of the predicted output; also part of
                the input width.
            action_shape: Indexable whose first entry is added to
                ``encoder_feature_dim`` to obtain the input width.
            layer_width: Width of the hidden layer.
        """
        super().__init__()
        input_dim = encoder_feature_dim + action_shape[0]
        self.fc = nn.Linear(input_dim, layer_width)
        self.ln = nn.LayerNorm(layer_width)
        self.fc_mu = nn.Linear(layer_width, encoder_feature_dim)
        print("Deterministic transition model chosen.")

    def forward(self, x):
        """Return ``(mu, sigma)`` where ``sigma`` is always ``None``.

        NOTE(review): ``x`` is presumably an encoded state concatenated with
        an action along the last dimension -- confirm against the caller.
        """
        hidden = torch.relu(self.ln(self.fc(x)))
        # The second element keeps the (mu, sigma) return shape used by the
        # probabilistic model; there is no sigma here.
        return self.fc_mu(hidden), None

    def sample_prediction(self, x):
        """Return the predicted mean for ``x``."""
        prediction, _ = self(x)
        return prediction
class ProbabilisticTransitionModel(nn.Module):
    """Transition model that predicts a mean and a bounded sigma.

    A shared trunk (Linear -> LayerNorm -> ReLU) feeds two linear heads, one
    for ``mu`` and one for ``sigma``. The sigma head is squashed with a
    sigmoid and rescaled into the range between ``min_sigma`` and
    ``max_sigma``.
    """

    def __init__(self, encoder_feature_dim, action_shape, layer_width, announce=True, max_sigma=1e1, min_sigma=1e-4):
        """Build the trunk and both heads.

        Args:
            encoder_feature_dim: Width of both outputs; also part of the
                input width.
            action_shape: Indexable whose first entry is added to
                ``encoder_feature_dim`` to obtain the input width.
            layer_width: Width of the hidden layer.
            announce: When true, print the chosen model on stdout.
            max_sigma: Upper end of the sigma range.
            min_sigma: Lower end of the sigma range.

        Raises:
            AssertionError: If ``max_sigma`` is smaller than ``min_sigma``.
        """
        super().__init__()
        input_dim = encoder_feature_dim + action_shape[0]
        self.fc = nn.Linear(input_dim, layer_width)
        self.ln = nn.LayerNorm(layer_width)
        self.fc_mu = nn.Linear(layer_width, encoder_feature_dim)
        self.fc_sigma = nn.Linear(layer_width, encoder_feature_dim)

        self.max_sigma = max_sigma
        self.min_sigma = min_sigma
        assert self.max_sigma >= self.min_sigma

        if announce:
            print("Probabilistic transition model chosen.")

    def forward(self, x):
        """Return ``(mu, sigma)`` computed from ``x``."""
        hidden = torch.relu(self.ln(self.fc(x)))
        mu = self.fc_mu(hidden)

        # Sigmoid output lies in (0, 1); rescale it to (min_sigma, max_sigma).
        unit_sigma = torch.sigmoid(self.fc_sigma(hidden))
        sigma_span = self.max_sigma - self.min_sigma
        sigma = self.min_sigma + sigma_span * unit_sigma
        return mu, sigma

    def sample_prediction(self, x):
        """Draw one sample as ``mu + sigma * noise`` with standard-normal noise."""
        mu, sigma = self(x)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise
class EnsembleOfProbabilisticTransitionModels(object):
    """A list of separately constructed ProbabilisticTransitionModel members.

    This is a plain Python object, not an ``nn.Module``. It provides its own
    ``to`` and ``parameters`` methods that fan out over the members held in
    ``self.models``.
    """

    def __init__(self, encoder_feature_dim, action_shape, layer_width, ensemble_size=5):
        """Create ``ensemble_size`` members and announce the choice on stdout.

        Args:
            encoder_feature_dim: Forwarded to every member.
            action_shape: Forwarded to every member.
            layer_width: Forwarded to every member.
            ensemble_size: Number of members to create; must be at least 1.

        Raises:
            ValueError: If ``ensemble_size`` is smaller than 1. An empty
                ensemble cannot be called or sampled from, so fail here
                rather than with an opaque error later.
        """
        if ensemble_size < 1:
            raise ValueError(
                f"ensemble_size must be at least 1, got {ensemble_size!r}"
            )
        self.models = [
            ProbabilisticTransitionModel(encoder_feature_dim, action_shape, layer_width, announce=False)
            for _ in range(ensemble_size)
        ]
        print("Ensemble of probabilistic transition models chosen.")

    def __call__(self, x):
        """Run every member on ``x`` and stack the results.

        Returns:
            ``(mus, sigmas)``: two tensors, each with a new leading dimension
            of size ``len(self.models)``.
        """
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on a member are executed.
        mu_sigma_list = [model(x) for model in self.models]
        mus, sigmas = zip(*mu_sigma_list)
        return torch.stack(mus), torch.stack(sigmas)

    def sample_prediction(self, x):
        """Sample a prediction from one member picked with ``random.choice``."""
        model = random.choice(self.models)
        return model.sample_prediction(x)

    def to(self, device):
        """Move every member to ``device`` and return this ensemble."""
        for model in self.models:
            model.to(device)
        return self

    def parameters(self):
        """Return the parameters of all members as one flat list."""
        return [p for model in self.models for p in model.parameters()]
# Registry mapping a transition-model type string to the class implementing it.
# The empty string maps to the same class as 'deterministic'.
_AVAILABLE_TRANSITION_MODELS = {
    '': DeterministicTransitionModel,
    'deterministic': DeterministicTransitionModel,
    'probabilistic': ProbabilisticTransitionModel,
    'ensemble': EnsembleOfProbabilisticTransitionModels,
}
def make_transition_model(transition_model_type, encoder_feature_dim, action_shape, layer_width=512):
    """Instantiate the transition model registered under ``transition_model_type``.

    Args:
        transition_model_type: Key into ``_AVAILABLE_TRANSITION_MODELS``.
        encoder_feature_dim: Passed through to the model constructor.
        action_shape: Passed through to the model constructor.
        layer_width: Passed through to the model constructor.

    Returns:
        A new instance of the selected class, built with
        ``(encoder_feature_dim, action_shape, layer_width)``.

    Raises:
        AssertionError: If ``transition_model_type`` is not a registered key.
    """
    assert transition_model_type in _AVAILABLE_TRANSITION_MODELS
    model_cls = _AVAILABLE_TRANSITION_MODELS[transition_model_type]
    return model_cls(encoder_feature_dim, action_shape, layer_width)