Adding model architecture for Reward, Value and Target Value

This commit is contained in:
Vedant Dave 2023-04-10 13:18:41 +02:00
parent 47090449d1
commit 8fd56ba94d

View File

@ -93,7 +93,7 @@ class ObservationDecoder(nn.Module):
return out_dist return out_dist
class ActionDecoder(nn.Module): class Actor(nn.Module):
def __init__(self, state_size, hidden_size, action_size, num_layers=5): def __init__(self, state_size, hidden_size, action_size, num_layers=5):
super().__init__() super().__init__()
self.state_size = state_size self.state_size = state_size
@ -153,6 +153,22 @@ class ValueModel(nn.Module):
return value_dist return value_dist
class RewardModel(nn.Module):
def __init__(self, state_size, hidden_size):
super().__init__()
self.reward_model = nn.Sequential(
nn.Linear(state_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 1)
)
def forward(self, state):
reward = self.reward_model(state).squeeze(dim=1)
return reward
class TransitionModel(nn.Module): class TransitionModel(nn.Module):
def __init__(self, state_size, hidden_size, action_size, history_size): def __init__(self, state_size, hidden_size, action_size, history_size):
super().__init__() super().__init__()
@ -195,7 +211,6 @@ class TransitionModel(nn.Module):
return prior return prior
def stack_states(self, states, dim=0): def stack_states(self, states, dim=0):
s = dict( s = dict(
mean = torch.stack([state['mean'] for state in states], dim=dim), mean = torch.stack([state['mean'] for state in states], dim=dim),
std = torch.stack([state['std'] for state in states], dim=dim), std = torch.stack([state['std'] for state in states], dim=dim),