# Curiosity/ppo/models.py
# 2023-02-01 19:36:09 +01:00
#
# 147 lines
# 5.2 KiB
# Python

import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.distributions import Categorical
# Module-wide compute device: the first CUDA device when one is available,
# otherwise the CPU. Every model below moves its layers onto it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class ICM(nn.Module):
    """Intrinsic Curiosity Module: a feature encoder plus inverse and forward models.

    ``feature_encoder`` maps an observation stack to a vector of width
    ``encoded_state_size``. ``inverse_model`` takes the concatenated encodings
    of a state and its successor and outputs a softmax distribution over
    ``action_size`` actions. ``forward_model`` takes the encoded state
    concatenated with an action vector of width ``action_size`` and predicts
    the encoded next state.

    Args:
        channels: Number of input channels of each observation.
        encoded_state_size: Width of the encoded feature vector.
        action_size: Number of discrete actions, i.e. the width of the action
            vector fed to the forward model.
    """

    def __init__(self, channels, encoded_state_size, action_size):
        super().__init__()
        self.channels = channels
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size
        # Four kernel-3 / stride-2 convolutions. The Linear input of 32*4*4
        # assumes the spatial size shrinks to 4x4, e.g. an 84x84 frame
        # (84 -> 41 -> 20 -> 9 -> 4).
        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32 * 4 * 4, out_features=self.encoded_state_size),
        ).to(device)
        # Outputs probabilities (Softmax), not logits.
        # NOTE(review): if the caller trains this with nn.CrossEntropyLoss the
        # softmax is effectively applied twice — confirm which loss is used.
        self.inverse_model = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size * 2, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.action_size),
            nn.Softmax(dim=-1),
        ).to(device)
        self.forward_model = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size + self.action_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.encoded_state_size),
        ).to(device)

    def forward(self, state, next_state, action):
        """Encode both states and run the inverse and forward models.

        Args:
            state: Observation, ``(C, H, W)`` or batched ``(B, C, H, W)``.
            next_state: Successor observation, same layout as ``state``.
            action: Action vector of width ``action_size``, either
                ``(action_size,)`` or ``(B, action_size)`` (presumably one-hot —
                confirm with the caller).

        Returns:
            ``(encoded_state, next_encoded_state, action_pred,
            next_encoded_state_pred)``, each with a leading batch dimension.
        """
        # Accept single, unbatched samples by adding a batch dimension; each
        # input is checked on its own rank.
        if state.dim() == 3:
            state = state.unsqueeze(0)
        if next_state.dim() == 3:
            next_state = next_state.unsqueeze(0)
        # The encodings are 2-D (batch, features). A 1-D action must gain a
        # batch dim too, or torch.cat below fails on mismatched ranks.
        if action.dim() == 1:
            action = action.unsqueeze(0)
        encoded_state = self.feature_encoder(state)
        next_encoded_state = self.feature_encoder(next_state)
        action_pred = self.inverse_model(
            torch.cat((encoded_state, next_encoded_state), dim=-1)
        )
        next_encoded_state_pred = self.forward_model(
            torch.cat((encoded_state, action), dim=-1)
        )
        return encoded_state, next_encoded_state, action_pred, next_encoded_state_pred

    def _init_weights(self):
        """Re-initialise every nn.Linear: Xavier-uniform weights, zero biases.

        Conv2d layers keep their default initialisation. ``__init__`` does not
        call this; a caller must invoke it explicitly.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
class ActorCritic(nn.Module):
    """Actor-critic network with a convolutional encoder and policy/value heads.

    Observations are encoded by ``feature_encoder`` into a vector of width
    ``encoded_state_size``. That vector feeds two heads: ``policy_head`` (softmax
    over ``action_size`` actions) and ``value_head`` (one scalar).

    Args:
        encoded_state_size: Width of the encoded feature vector.
        action_size: Number of discrete actions.
        state_size: Number of input channels of each observation (e.g. the
            number of stacked frames).
    """

    def __init__(self, encoded_state_size, action_size, state_size=4):
        super().__init__()
        # ``state_size`` is used as the encoder's number of input channels.
        self.channels = state_size
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size
        # Same encoder layout as ICM: the 32*4*4 Linear input assumes the
        # convolutions shrink the frame to 4x4 (e.g. 84x84 input).
        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32 * 4 * 4, out_features=self.encoded_state_size),
        ).to(device)
        # Build the heads once, as registered submodules. Previously they were
        # recreated (with fresh random weights) on every call. They were then
        # never trained, never saved, and gave a different output each time.
        self.policy_head = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.action_size),
            nn.Softmax(dim=-1),
        ).to(device)
        self.value_head = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=1),
        ).to(device)

    def actor(self, state):
        """Return action probabilities for already-encoded ``state`` (..., encoded_state_size)."""
        return self.policy_head(state)

    def critic(self, state):
        """Return value estimates (..., 1) for already-encoded ``state``."""
        return self.value_head(state)

    def forward(self, state):
        """Encode ``state`` and evaluate both heads.

        Args:
            state: Observation, ``(C, H, W)`` or batched ``(B, C, H, W)``.

        Returns:
            ``(Categorical over actions, value tensor of shape (B, 1))``.
        """
        if state.dim() == 3:
            # Add a batch dimension for a single unbatched observation.
            state = state.unsqueeze(0)
        encoded = self.feature_encoder(state)
        value = self.critic(encoded)
        policy = self.actor(encoded)
        # ``policy`` holds probabilities, so it is passed as Categorical's ``probs``.
        actions = Categorical(policy)
        return actions, value

    def _init_weights(self):
        """Re-initialise every nn.Linear (encoder and both heads).

        Weights get Xavier-uniform init and biases are set to zero. Conv2d
        layers keep their default initialisation. ``__init__`` does not call
        this method.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
# Policy and value model
class ActorCriticNetwork(nn.Module):
    """MLP actor-critic for flat (vector) observations.

    A two-layer, 64-unit ReLU trunk (``shared_layers``) feeds two heads:
    ``policy_layers`` ends in a softmax over the actions, and ``value_layers``
    produces one scalar.

    Args:
        obs_space_size: Length of the observation vector.
        action_space_size: Number of discrete actions.
    """

    def __init__(self, obs_space_size, action_space_size):
        super().__init__()
        trunk = [
            nn.Linear(obs_space_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        ]
        self.shared_layers = nn.Sequential(*trunk).to(device)
        # Construction order (trunk, then policy head, then value head) is kept
        # so that parameter initialisation and state_dict keys are unchanged.
        policy_stack = self._head_layers(action_space_size)
        policy_stack.append(nn.Softmax(dim=-1))
        self.policy_layers = nn.Sequential(*policy_stack).to(device)
        self.value_layers = nn.Sequential(*self._head_layers(1)).to(device)

    @staticmethod
    def _head_layers(out_features):
        """Build the Linear(64, 64) -> ReLU -> Linear(64, out_features) stack used by both heads."""
        return [nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, out_features)]

    def value(self, obs):
        """Return the value estimate, with last dimension 1, for ``obs``."""
        features = self.shared_layers(obs)
        return self.value_layers(features)

    def policy(self, obs):
        """Return action probabilities for ``obs`` (softmax output, not logits)."""
        features = self.shared_layers(obs)
        return self.policy_layers(features)

    def forward(self, obs):
        """Return ``(Categorical over actions, value estimate)``, sharing one trunk pass."""
        features = self.shared_layers(obs)
        probs = self.policy_layers(features)
        estimate = self.value_layers(features)
        return Categorical(probs), estimate