import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Encoder(nn.Module):
    """Encodes a stack of frames into a compact feature vector.

    The four stride-2 convolutions reduce an 84x84 input to 4x4 feature
    maps, which is what the 32 * 4 * 4 flattened size below assumes.
    """

    def __init__(self, channels, encoded_state_size):
        super(Encoder, self).__init__()
        self.channels = channels
        self.encoded_state_size = encoded_state_size
        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32 * 4 * 4, out_features=self.encoded_state_size),
        ).to(device)

    def forward(self, state):
        # Accept a single unbatched observation as well as a batch.
        if state.dim() == 3:
            state = state.unsqueeze(0)
        return self.feature_encoder(state)


class InverseModel(nn.Module):
    """Predicts the action that led from one encoded state to the next."""

    def __init__(self, encoded_state_size, action_size=2):
        super(InverseModel, self).__init__()
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size
        self.model = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size * 2, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.action_size),
            nn.Softmax(dim=-1),
        ).to(device)
        self._init_weights()

    def forward(self, encoded_state, next_encoded_state):
        encoded_states = torch.cat((encoded_state, next_encoded_state), dim=-1)
        return Categorical(self.model(encoded_states))

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)


class ForwardModel(nn.Module):
    """Predicts the next encoded state from the current encoding and action."""

    def __init__(self, encoded_state_size, action_size):
        super(ForwardModel, self).__init__()
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size
        self.model = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size + self.action_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.encoded_state_size),
        ).to(device)
        self._init_weights()

    def forward(self, state, action):
        # The action is expected as a float (e.g. one-hot) vector so it can
        # be concatenated with the encoded state.
        state = torch.cat((state, action), dim=-1)
        return self.model(state)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)


class Actor(nn.Module):
    """Policy network: encodes the observation and returns an action distribution."""

    def __init__(self, encoded_state_size, action_size, state_size=4):
        super(Actor, self).__init__()
        self.channels = state_size
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size
        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32 * 4 * 4, out_features=self.encoded_state_size),
        ).to(device)
        # The policy head must be built once here: constructing it inside
        # forward() would re-initialize its weights on every call and hide
        # them from the optimizer.
        self.policy_head = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.action_size),
            nn.Softmax(dim=-1),
        ).to(device)
        self._init_weights()

    def forward(self, state):
        if state.dim() == 3:
            state = state.unsqueeze(0)
        encoded = self.feature_encoder(state)
        return Categorical(self.policy_head(encoded))

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)


class Critic(nn.Module):
    """Value network: encodes the observation and returns a scalar state value."""

    def __init__(self, encoded_state_size, state_size=4):
        super(Critic, self).__init__()
        self.channels = state_size
        self.encoded_state_size = encoded_state_size
        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32 * 4 * 4, out_features=self.encoded_state_size),
        ).to(device)
        # Built once in __init__ for the same reason as Actor.policy_head.
        self.value_head = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=1),
        ).to(device)
        self._init_weights()

    def forward(self, state):
        if state.dim() == 3:
            state = state.unsqueeze(0)
        encoded = self.feature_encoder(state)
        return self.value_head(encoded)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
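

# A minimal usage sketch showing how these modules are typically wired
# together in an Intrinsic Curiosity Module (ICM) setup, where the forward
# model's prediction error serves as the intrinsic reward. The 84x84
# four-channel frame stacks, the action count, and the eta scaling factor
# below are illustrative assumptions, not values fixed by this file.
if __name__ == "__main__":
    encoded_state_size, action_size = 256, 2
    encoder = Encoder(channels=4, encoded_state_size=encoded_state_size)
    inverse_model = InverseModel(encoded_state_size, action_size)
    forward_model = ForwardModel(encoded_state_size, action_size)
    actor = Actor(encoded_state_size, action_size)
    critic = Critic(encoded_state_size)

    state = torch.rand(1, 4, 84, 84, device=device)
    next_state = torch.rand(1, 4, 84, 84, device=device)

    # Act and evaluate the current state.
    action = actor(state).sample()
    value = critic(state)

    # Encode both observations, then run the inverse and forward models.
    phi, next_phi = encoder(state), encoder(next_state)
    action_dist = inverse_model(phi, next_phi)
    inverse_loss = -action_dist.log_prob(action).mean()

    one_hot = F.one_hot(action, num_classes=action_size).float()
    predicted_next_phi = forward_model(phi, one_hot)
    forward_loss = 0.5 * (predicted_next_phi - next_phi.detach()).pow(2).sum(dim=-1).mean()

    # Curiosity bonus: scaled forward-model prediction error.
    eta = 0.01  # assumed reward scaling coefficient
    intrinsic_reward = eta * forward_loss.detach()
    print(action.item(), value.item(), inverse_loss.item(), intrinsic_reward.item())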