Compare commits
No commits in common. "c3f6e9f281ed5bd6a5afc77ab12f20bca4e023f8" and "0781d4fd05b8fc017eb7cee6b54c81d94b81735b" have entirely different histories.
c3f6e9f281...0781d4fd05
icm cartpole.py (196 lines removed)
@@ -1,196 +0,0 @@
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim
import collections

env = gym.make('CartPole-v1')


class Actor(nn.Module):
    def __init__(self, n_actions, space_dims, hidden_dims):
        super(Actor, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Linear(space_dims, hidden_dims),
            nn.ReLU(True),
        )
        self.actor = nn.Sequential(
            nn.Linear(hidden_dims, n_actions),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        policy = self.actor(features)
        return policy


class Critic(nn.Module):
    def __init__(self, space_dims, hidden_dims):
        super(Critic, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Linear(space_dims, hidden_dims),
            nn.ReLU(True),
        )
        self.critic = nn.Linear(hidden_dims, 1)

    def forward(self, x):
        features = self.feature_extractor(x)
        est_reward = self.critic(features)
        return est_reward


class InverseModel(nn.Module):
    def __init__(self, n_actions, hidden_dims):
        super(InverseModel, self).__init__()
        self.fc = nn.Linear(hidden_dims*2, n_actions)

    def forward(self, features):
        features = features.view(1, -1)  # (1, hidden_dims*2)
        action = self.fc(features)       # (1, n_actions)
        return action


class ForwardModel(nn.Module):
    def __init__(self, n_actions, hidden_dims):
        super(ForwardModel, self).__init__()
        self.fc = nn.Linear(hidden_dims+n_actions, hidden_dims)
        self.eye = torch.eye(n_actions)

    def forward(self, action, features):
        x = torch.cat([self.eye[action], features], dim=-1)  # (1, n_actions+hidden_dims)
        features = self.fc(x)                                 # (1, hidden_dims)
        return features


class FeatureExtractor(nn.Module):
    def __init__(self, space_dims, hidden_dims):
        super(FeatureExtractor, self).__init__()
        self.fc = nn.Linear(space_dims, hidden_dims)

    def forward(self, x):
        y = torch.tanh(self.fc(x))
        return y


class PGLoss(nn.Module):
    def __init__(self):
        super(PGLoss, self).__init__()

    def forward(self, action_prob, reward):
        loss = -torch.mean(torch.log(action_prob+1e-6)*reward)
        return loss


def select_action(policy):
    return np.random.choice(len(policy), 1, p=policy)[0]


def to_tensor(x, dtype=None):
    return torch.tensor(x, dtype=dtype).unsqueeze(0)


class ConfigArgs:
    beta = 0.2
    lamda = 0.1
    eta = 100.0  # scale factor for intrinsic reward
    discounted_factor = 0.99
    lr_critic = 0.005
    lr_actor = 0.001
    lr_icm = 0.001
    max_eps = 1000
    sparse_mode = True


args = ConfigArgs()

# Actor Critic
actor = Actor(n_actions=env.action_space.n, space_dims=4, hidden_dims=32)
critic = Critic(space_dims=4, hidden_dims=32)

# ICM
feature_extractor = FeatureExtractor(env.observation_space.shape[0], 32)
forward_model = ForwardModel(env.action_space.n, 32)
inverse_model = InverseModel(env.action_space.n, 32)

# Actor Critic
a_optim = torch.optim.Adam(actor.parameters(), lr=args.lr_actor)
c_optim = torch.optim.Adam(critic.parameters(), lr=args.lr_critic)

# ICM
icm_params = list(feature_extractor.parameters()) + list(forward_model.parameters()) + list(inverse_model.parameters())
icm_optim = torch.optim.Adam(icm_params, lr=args.lr_icm)

pg_loss = PGLoss()
mse_loss = nn.MSELoss()
xe_loss = nn.CrossEntropyLoss()

global_step = 0
n_eps = 0
reward_lst = []
mva_lst = []
mva = 0.
avg_ireward_lst = []

while n_eps < args.max_eps:
    n_eps += 1
    next_obs = to_tensor(env.reset(), dtype=torch.float)
    done = False
    score = 0
    ireward_lst = []

    while not done:
        env.render()
        obs = next_obs
        a_optim.zero_grad()
        c_optim.zero_grad()
        icm_optim.zero_grad()

        # estimate action with policy network
        policy = actor(obs)
        action = select_action(policy.detach().numpy()[0])

        # interaction with environment
        next_obs, reward, done, info = env.step(action)
        next_obs = to_tensor(next_obs, dtype=torch.float)
        advantages = torch.zeros_like(policy)
        extrinsic_reward = to_tensor([0.], dtype=torch.float) if args.sparse_mode else to_tensor([reward], dtype=torch.float)
        t_action = to_tensor(action)

        v = critic(obs)[0]
        next_v = critic(next_obs)[0]

        # ICM
        obs_cat = torch.cat([obs, next_obs], dim=0)
        features = feature_extractor(obs_cat)          # (2, hidden_dims)
        inverse_action_prob = inverse_model(features)  # (1, n_actions)
        est_next_features = forward_model(t_action, features[0:1])

        # Loss - ICM
        forward_loss = mse_loss(est_next_features, features[1])
        inverse_loss = xe_loss(inverse_action_prob, t_action.view(-1))
        icm_loss = (1-args.beta)*inverse_loss + args.beta*forward_loss

        # Reward
        intrinsic_reward = args.eta*forward_loss.detach()
        if done:
            total_reward = -100 + intrinsic_reward if score < 499 else intrinsic_reward
            advantages[0, action] = total_reward - v
            c_target = total_reward
        else:
            total_reward = extrinsic_reward + intrinsic_reward
            advantages[0, action] = total_reward + args.discounted_factor*next_v - v
            c_target = total_reward + args.discounted_factor*next_v

        # Loss - Actor Critic
        actor_loss = pg_loss(policy, advantages.detach())
        critic_loss = mse_loss(v, c_target.detach())
        ac_loss = actor_loss + critic_loss

        # Update
        loss = args.lamda*ac_loss + icm_loss
        loss.backward()
        icm_optim.step()
        a_optim.step()
        c_optim.step()

        if not done:
            score += reward

        ireward_lst.append(intrinsic_reward.item())

        global_step += 1

    avg_intrinsic_reward = sum(ireward_lst) / len(ireward_lst)
    mva = 0.95*mva + 0.05*score
    reward_lst.append(score)
    avg_ireward_lst.append(avg_intrinsic_reward)
    mva_lst.append(mva)
    print('Episodes: {}, AVG Score: {:.3f}, Score: {}, AVG reward i: {:.6f}'.format(n_eps, mva, score, avg_intrinsic_reward))
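For reference, the training step above optimizes one combined objective, loss = lamda*(actor_loss + critic_loss) + (1-beta)*inverse_loss + beta*forward_loss, and the curiosity bonus added to the reward is eta times the detached forward-model error. Below is a minimal sketch of the ICM modules defined in this file on a single CartPole transition; the zero/one observations, the action index, and the explicit eta = 100.0 are illustrative values, not taken from a run.

import torch
import torch.nn.functional as F

# illustrative single CartPole transition (obs, action, next_obs)
obs      = torch.zeros(1, 4)
next_obs = torch.ones(1, 4)
action   = torch.tensor([1])

phi = FeatureExtractor(space_dims=4, hidden_dims=32)
fwd = ForwardModel(n_actions=2, hidden_dims=32)
inv = InverseModel(n_actions=2, hidden_dims=32)

features      = phi(torch.cat([obs, next_obs], dim=0))  # (2, 32): phi(s) and phi(s')
est_next_feat = fwd(action, features[0:1])               # predicted phi(s')
action_logits = inv(features)                            # predicted action from (phi(s), phi(s'))

# curiosity bonus: scaled, detached forward-model error (eta = 100.0 as in ConfigArgs)
intrinsic_reward = 100.0 * F.mse_loss(est_next_feat, features[1:2]).detach()
print(intrinsic_reward.item(), action_logits.shape)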
mario_env.py (175 lines removed)
@@ -1,175 +0,0 @@
import cv2
import numpy as np
import collections

import gym
from gym.spaces import Box

import torch
import torch.nn.functional as F
from torchvision import transforms as T

import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT, COMPLEX_MOVEMENT


class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        """Return only every `skip`-th frame"""
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        """Repeat action, and sum reward"""
        total_reward = 0.0
        for i in range(self._skip):
            # Accumulate reward and repeat the same action
            obs, reward, done, trunk, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, trunk, info


class GrayScaleObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def permute_orientation(self, observation):
        # permute [H, W, C] array to [C, H, W] tensor
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        return observation

    def observation(self, observation):
        observation = self.permute_orientation(observation)
        transform = T.Grayscale()
        observation = transform(observation)
        return observation


class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, shape):
        super().__init__(env)
        if isinstance(shape, int):
            self.shape = (shape, shape)
        else:
            self.shape = tuple(shape)

        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        transforms = T.Compose(
            [T.Resize(self.shape), T.Normalize(0, 255)]
        )
        observation = transforms(observation).squeeze(0)
        return observation


class MaxAndSkipEnv(gym.Wrapper):
    """
    Each action of the agent is repeated over skip frames
    return only every `skip`-th frame
    """
    def __init__(self, env=None, skip=4):
        super(MaxAndSkipEnv, self).__init__(env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self):
        """Clear past frame buffer and init to first obs"""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs


class MarioRescale84x84(gym.ObservationWrapper):
    """
    Downsamples/Rescales each frame to size 84x84 with greyscale
    """
    def __init__(self, env=None):
        super(MarioRescale84x84, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return MarioRescale84x84.process(obs)

    @staticmethod
    def process(frame):
        if frame.size == 240 * 256 * 3:
            img = np.reshape(frame, [240, 256, 3]).astype(np.float32)
        else:
            assert False, "Unknown resolution."
        # image normalization on RGB
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)


class ImageToPyTorch(gym.ObservationWrapper):
    """
    Each frame is converted to PyTorch tensors
    """
    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)


class BufferWrapper(gym.ObservationWrapper):
    """
    Only every k-th frame is collected by the buffer
    """
    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0),
                                                old_space.high.repeat(n_steps, axis=0), dtype=dtype)

    def reset(self):
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        return self.observation(self.env.reset())

    def observation(self, observation):
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer


class PixelNormalization(gym.ObservationWrapper):
    """
    Normalize pixel values in frame --> 0 to 1
    """
    def observation(self, obs):
        return np.array(obs).astype(np.float32) / 255.0


def create_mario_env(env):
    env = MaxAndSkipEnv(env)
    env = MarioRescale84x84(env)
    env = ImageToPyTorch(env)
    env = BufferWrapper(env, 4)
    env = PixelNormalization(env)
    return JoypadSpace(env, COMPLEX_MOVEMENT)
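create_mario_env composes the wrappers in a fixed order: frame max-and-skip, 84x84 grayscale rescale, channel-first conversion, a 4-frame buffer, pixel normalization to [0, 1], and finally JoypadSpace with COMPLEX_MOVEMENT. Below is a minimal usage sketch; the stage id 'SuperMarioBros-1-1-v0' is an illustrative choice, and it assumes the classic 4-tuple gym step API that MaxAndSkipEnv expects.

import gym_super_mario_bros

env = create_mario_env(gym_super_mario_bros.make('SuperMarioBros-1-1-v0'))

obs = env.reset()                                   # float32 array of shape (4, 84, 84) in [0, 1]
obs, reward, done, info = env.step(env.action_space.sample())
env.close()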
models.py (159 lines removed)
@@ -1,159 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.distributions import Categorical

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Encoder(nn.Module):
    def __init__(self, channels, encoded_state_size):
        super(Encoder, self).__init__()
        self.channels = channels
        self.encoded_state_size = encoded_state_size

        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
        ).to(device)

    def forward(self, state):
        if state.dim() == 3:
            state = state.unsqueeze(0)
        state = self.feature_encoder(state)
        return state


class InverseModel(nn.Module):
    def __init__(self, encoded_state_size, action_size=2):
        super(InverseModel, self).__init__()
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size

        self.model = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size*2, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.action_size),
            nn.Softmax(dim=-1)
        ).to(device)

    def forward(self, encoded_state, next_encoded_state):
        encoded_states = torch.cat((encoded_state, next_encoded_state), dim=-1)
        actions = Categorical(self.model(encoded_states))
        return actions

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)


class ForwardModel(nn.Module):
    def __init__(self, encoded_state_size, action_size):
        super(ForwardModel, self).__init__()
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size

        self.model = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size+self.action_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.encoded_state_size),
        ).to(device)

    def forward(self, state, action):
        state = torch.cat((state, action), dim=-1)
        return self.model(state)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)


class Actor(nn.Module):
    def __init__(self, encoded_state_size, action_size, state_size=4):
        super(Actor, self).__init__()
        self.channels = state_size
        self.encoded_state_size = encoded_state_size
        self.action_size = action_size

        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
        ).to(device)

        # policy head is defined once here so its parameters are registered with the
        # module (and therefore updated by the optimizer), rather than rebuilt per call
        self.actor = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=self.action_size),
            nn.Softmax(dim=-1)
        ).to(device)

    def forward(self, state):
        state = self.feature_encoder(state)
        policy = self.actor(state)
        actions = Categorical(policy)
        return actions

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)


class Critic(nn.Module):
    def __init__(self, encoded_state_size, state_size=4):
        super(Critic, self).__init__()
        self.channels = state_size
        self.encoded_state_size = encoded_state_size

        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.channels, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(in_features=32*4*4, out_features=self.encoded_state_size),
        ).to(device)

        # value head is defined once here so its parameters are registered and trained
        self.critic = nn.Sequential(
            nn.Linear(in_features=self.encoded_state_size, out_features=256),
            nn.LeakyReLU(),
            nn.Linear(in_features=256, out_features=1),
        ).to(device)

    def forward(self, state):
        state = self.feature_encoder(state)
        value = self.critic(state)
        return value

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
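Each convolutional stack above shrinks an 84x84 input to 4x4 over four stride-2 convolutions, which is why the flattening linear layer expects 32*4*4 input features. Below is a minimal smoke-test sketch of these modules; the four stacked 84x84 frames, encoded_state_size = 256, and action_size = 12 are illustrative values, not taken from this diff.

import torch
import torch.nn.functional as F

encoded_state_size, action_size = 256, 12            # illustrative sizes
state      = torch.zeros(1, 4, 84, 84).to(device)
next_state = torch.zeros(1, 4, 84, 84).to(device)

encoder   = Encoder(channels=4, encoded_state_size=encoded_state_size)
forward_m = ForwardModel(encoded_state_size, action_size)
inverse_m = InverseModel(encoded_state_size, action_size)
actor     = Actor(encoded_state_size, action_size)
critic    = Critic(encoded_state_size)

phi, phi_next = encoder(state), encoder(next_state)   # (1, 256) embeddings
action  = actor(state).sample()                       # sample from the Categorical policy
one_hot = F.one_hot(action, action_size).float().to(device)

pred_phi_next = forward_m(phi, one_hot)               # (1, 256) predicted next embedding
pred_action   = inverse_m(phi, phi_next)              # Categorical over the 12 actions
value         = critic(state)                         # (1, 1) state value
print(pred_phi_next.shape, pred_action.probs.shape, value.shape)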