Neural Network Models
This commit is contained in:
parent
bc1b46247d
commit
18dd8cc8cf
159
models.py
Normal file
159
models.py
Normal file
@ -0,0 +1,159 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as f
|
||||
from torch.distributions import Categorical
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, channels, encoded_state_size):
|
||||
super(Encoder, self).__init__()
|
||||
self.channels = channels
|
||||
self.encoded_state_size = encoded_state_size
|
||||
|
||||
self.feature_encoder = nn.Sequential(
|
||||
nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Flatten(),
|
||||
nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
|
||||
).to(device)
|
||||
|
||||
def forward(self, state):
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(0)
|
||||
state = self.feature_encoder(state)
|
||||
return state
|
||||
|
||||
|
||||
class InverseModel(nn.Module):
|
||||
def __init__(self, encoded_state_size, action_size=2):
|
||||
super(InverseModel, self).__init__()
|
||||
self.encoded_state_size = encoded_state_size
|
||||
self.action_size = action_size
|
||||
|
||||
self.model = nn.Sequential(
|
||||
nn.Linear(in_features=self.encoded_state_size*2, out_features=256),
|
||||
nn.LeakyReLU(),
|
||||
nn.Linear(in_features=256, out_features=self.action_size),
|
||||
nn.Softmax(dim=-1)
|
||||
).to(device)
|
||||
|
||||
def forward(self, encoded_state, next_encoded_state):
|
||||
encoded_states = torch.cat((encoded_state, next_encoded_state), dim=-1)
|
||||
actions = Categorical(self.model(encoded_states))
|
||||
return actions
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
class ForwardModel(nn.Module):
|
||||
def __init__(self, encoded_state_size, action_size):
|
||||
super(ForwardModel, self).__init__()
|
||||
self.encoded_state_size = encoded_state_size
|
||||
self.action_size = action_size
|
||||
|
||||
self.model = nn.Sequential(
|
||||
nn.Linear(in_features=self.encoded_state_size+self.action_size, out_features=256),
|
||||
nn.LeakyReLU(),
|
||||
nn.Linear(in_features=256, out_features=self.encoded_state_size),
|
||||
).to(device)
|
||||
|
||||
def forward(self, state, action):
|
||||
state = torch.cat((state, action), dim=-1)
|
||||
return self.model(state)
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
|
||||
class Actor(nn.Module):
|
||||
def __init__(self,encoded_state_size, action_size, state_size=4):
|
||||
super(Actor, self).__init__()
|
||||
self.channels = state_size
|
||||
self.encoded_state_size = encoded_state_size
|
||||
self.action_size = action_size
|
||||
|
||||
self.feature_encoder = nn.Sequential(
|
||||
nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Flatten(),
|
||||
nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
|
||||
).to(device)
|
||||
|
||||
def actor(self,state):
|
||||
policy = nn.Sequential(
|
||||
nn.Linear(in_features=self.encoded_state_size , out_features=256),
|
||||
nn.LeakyReLU(),
|
||||
nn.Linear(in_features=256, out_features=self.action_size),
|
||||
nn.Softmax(dim=-1)
|
||||
).to(device)
|
||||
return policy(state)
|
||||
|
||||
def forward(self, state):
|
||||
state = self.feature_encoder(state)
|
||||
policy = self.actor(state)
|
||||
actions = Categorical(policy)
|
||||
return actions
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(self, encoded_state_size, state_size=4):
|
||||
super(Critic, self).__init__()
|
||||
self.channels = state_size
|
||||
self.encoded_state_size = encoded_state_size
|
||||
|
||||
self.feature_encoder = nn.Sequential(
|
||||
nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
|
||||
nn.LeakyReLU(),
|
||||
nn.Flatten(),
|
||||
nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
|
||||
).to(device)
|
||||
|
||||
def critic(self,state):
|
||||
value = nn.Sequential(
|
||||
nn.Linear(in_features=self.encoded_state_size , out_features=256),
|
||||
nn.LeakyReLU(),
|
||||
nn.Linear(in_features=256, out_features=1),
|
||||
).to(device)
|
||||
return value(state)
|
||||
|
||||
def forward(self, state):
|
||||
state = self.feature_encoder(state)
|
||||
value = self.critic(state)
|
||||
return value
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
Loading…
Reference in New Issue
Block a user