Adding model architecture for Reward, Value and Target Value

This commit is contained in:
Vedant Dave 2023-04-10 13:18:41 +02:00
parent 47090449d1
commit 8fd56ba94d

View File

@ -93,7 +93,7 @@ class ObservationDecoder(nn.Module):
return out_dist return out_dist
class ActionDecoder(nn.Module): class Actor(nn.Module):
def __init__(self, state_size, hidden_size, action_size, num_layers=5): def __init__(self, state_size, hidden_size, action_size, num_layers=5):
super().__init__() super().__init__()
self.state_size = state_size self.state_size = state_size
@ -151,8 +151,24 @@ class ValueModel(nn.Module):
value = self.value_model(state) value = self.value_model(state)
value_dist = torch.distributions.independent.Independent(torch.distributions.Normal(value, 1), 1) value_dist = torch.distributions.independent.Independent(torch.distributions.Normal(value, 1), 1)
return value_dist return value_dist
class RewardModel(nn.Module):
    """MLP head that predicts a scalar reward from a latent state.

    Architecture: state_size -> hidden_size -> hidden_size -> 1, with ReLU
    activations between the linear layers.
    """

    def __init__(self, state_size, hidden_size):
        super().__init__()
        # Keep the attribute name `reward_model` so checkpoints / state_dicts
        # remain compatible.
        mlp_layers = [
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        ]
        self.reward_model = nn.Sequential(*mlp_layers)

    def forward(self, state):
        """Return predicted rewards with the trailing singleton dim removed.

        NOTE(review): squeeze(dim=1) assumes `state` is 2-D
        (batch, state_size), giving a (batch,) reward vector — confirm
        callers never pass higher-rank inputs.
        """
        predicted = self.reward_model(state)
        return predicted.squeeze(dim=1)
class TransitionModel(nn.Module): class TransitionModel(nn.Module):
def __init__(self, state_size, hidden_size, action_size, history_size): def __init__(self, state_size, hidden_size, action_size, history_size):
super().__init__() super().__init__()
@ -194,8 +210,7 @@ class TransitionModel(nn.Module):
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist} prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
return prior return prior
def stack_states(self, states, dim=0): def stack_states(self, states, dim=0):
s = dict( s = dict(
mean = torch.stack([state['mean'] for state in states], dim=dim), mean = torch.stack([state['mean'] for state in states], dim=dim),
std = torch.stack([state['std'] for state in states], dim=dim), std = torch.stack([state['std'] for state in states], dim=dim),