diff --git a/models.py b/models.py
new file mode 100644
index 0000000..47324a3
--- /dev/null
+++ b/models.py
@@ -0,0 +1,146 @@
+import torch
+import torch.nn as nn
+from torch.distributions import Categorical
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Encoder(nn.Module):
+    """Maps an image observation to a flat feature vector."""
+
+    def __init__(self, channels, encoded_state_size):
+        super(Encoder, self).__init__()
+        self.channels = channels
+        self.encoded_state_size = encoded_state_size
+
+        self.feature_encoder = nn.Sequential(
+            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
+            nn.LeakyReLU(),
+            nn.Flatten(),
+            # 32 channels over a 4x4 spatial map; assumes 84x84 input frames.
+            nn.Linear(in_features=32 * 4 * 4, out_features=self.encoded_state_size),
+        ).to(device)
+
+    def forward(self, state):
+        if state.dim() == 3:  # add a batch dimension to a single frame
+            state = state.unsqueeze(0)
+        return self.feature_encoder(state)
+
+
+class InverseModel(nn.Module):
+    """Predicts the action taken between two consecutive encoded states."""
+
+    def __init__(self, encoded_state_size, action_size=2):
+        super(InverseModel, self).__init__()
+        self.encoded_state_size = encoded_state_size
+        self.action_size = action_size
+
+        self.model = nn.Sequential(
+            nn.Linear(in_features=self.encoded_state_size * 2, out_features=256),
+            nn.LeakyReLU(),
+            nn.Linear(in_features=256, out_features=self.action_size),
+            nn.Softmax(dim=-1),
+        ).to(device)
+        self._init_weights()
+
+    def forward(self, encoded_state, next_encoded_state):
+        encoded_states = torch.cat((encoded_state, next_encoded_state), dim=-1)
+        return Categorical(self.model(encoded_states))
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                nn.init.zeros_(m.bias)
+
+
+class ForwardModel(nn.Module):
+    """Predicts the next encoded state from the current one and an action."""
+
+    def __init__(self, encoded_state_size, action_size):
+        super(ForwardModel, self).__init__()
+        self.encoded_state_size = encoded_state_size
+        self.action_size = action_size
+
+        self.model = nn.Sequential(
+            nn.Linear(in_features=self.encoded_state_size + self.action_size, out_features=256),
+            nn.LeakyReLU(),
+            nn.Linear(in_features=256, out_features=self.encoded_state_size),
+        ).to(device)
+        self._init_weights()
+
+    def forward(self, state, action):
+        state = torch.cat((state, action), dim=-1)
+        return self.model(state)
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                nn.init.zeros_(m.bias)
+
+
+class Actor(nn.Module):
+    """Policy network: encodes the observation and outputs an action distribution."""
+
+    def __init__(self, encoded_state_size, action_size, state_size=4):
+        super(Actor, self).__init__()
+        self.channels = state_size
+        self.encoded_state_size = encoded_state_size
+        self.action_size = action_size
+
+        self.feature_encoder = Encoder(self.channels, self.encoded_state_size)
+        # Build the policy head once here; constructing it inside forward()
+        # would re-initialize its weights on every call, so it could never learn.
+        self.actor = nn.Sequential(
+            nn.Linear(in_features=self.encoded_state_size, out_features=256),
+            nn.LeakyReLU(),
+            nn.Linear(in_features=256, out_features=self.action_size),
+            nn.Softmax(dim=-1),
+        ).to(device)
+        self._init_weights()
+
+    def forward(self, state):
+        state = self.feature_encoder(state)
+        policy = self.actor(state)
+        return Categorical(policy)
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                nn.init.zeros_(m.bias)
+
+
+class Critic(nn.Module):
+    """Value network: encodes the observation and outputs a scalar state value."""
+
+    def __init__(self, encoded_state_size, state_size=4):
+        super(Critic, self).__init__()
+        self.channels = state_size
+        self.encoded_state_size = encoded_state_size
+
+        self.feature_encoder = Encoder(self.channels, self.encoded_state_size)
+        # Built once in __init__ for the same reason as the Actor's policy head.
+        self.critic = nn.Sequential(
+            nn.Linear(in_features=self.encoded_state_size, out_features=256),
+            nn.LeakyReLU(),
+            nn.Linear(in_features=256, out_features=1),
+        ).to(device)
+        self._init_weights()
+
+    def forward(self, state):
+        state = self.feature_encoder(state)
+        return self.critic(state)
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                nn.init.zeros_(m.bias)
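
For context, here is a minimal sketch of how these modules could be wired together for one curiosity-style (ICM) step. The sizes (4 stacked 84x84 frames, 4 discrete actions, a 256-dim encoding) and the loss terms are illustrative assumptions, not the repo's actual training loop:

```python
import torch
import torch.nn.functional as F

from models import Encoder, InverseModel, ForwardModel, Actor, Critic, device

# Assumed sizes for illustration only.
encoded_state_size, action_size, channels = 256, 4, 4

encoder = Encoder(channels, encoded_state_size)
inverse_model = InverseModel(encoded_state_size, action_size)
forward_model = ForwardModel(encoded_state_size, action_size)
actor = Actor(encoded_state_size, action_size, state_size=channels)
critic = Critic(encoded_state_size, state_size=channels)

state = torch.rand(1, channels, 84, 84, device=device)
next_state = torch.rand(1, channels, 84, 84, device=device)

# Sample an action from the policy and evaluate the state.
action_dist = actor(state)          # Categorical over actions
action = action_dist.sample()       # shape (1,)
value = critic(state)               # shape (1, 1)

# Encode both observations for the curiosity losses.
phi, next_phi = encoder(state), encoder(next_state)
one_hot_action = F.one_hot(action, num_classes=action_size).float()

# Forward loss: error predicting the next encoding (doubles as intrinsic reward).
predicted_next_phi = forward_model(phi, one_hot_action)
forward_loss = F.mse_loss(predicted_next_phi, next_phi)

# Inverse loss: negative log-likelihood of recovering the action taken.
inverse_loss = -inverse_model(phi, next_phi).log_prob(action).mean()
```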