diff --git a/DPI/models.py b/DPI/models.py
index 2158ea7..0efdaeb 100644
--- a/DPI/models.py
+++ b/DPI/models.py
@@ -93,7 +93,7 @@ class ObservationDecoder(nn.Module):
         return out_dist
 
 
-class ActionDecoder(nn.Module):
+class Actor(nn.Module):
     def __init__(self, state_size, hidden_size, action_size, num_layers=5):
         super().__init__()
         self.state_size = state_size
@@ -151,8 +151,24 @@ class ValueModel(nn.Module):
         value = self.value_model(state)
         value_dist = torch.distributions.independent.Independent(torch.distributions.Normal(value, 1), 1)
         return value_dist
-        
-        
+
+
+class RewardModel(nn.Module):
+    def __init__(self, state_size, hidden_size):
+        super().__init__()
+        self.reward_model = nn.Sequential(
+            nn.Linear(state_size, hidden_size),
+            nn.ReLU(),
+            nn.Linear(hidden_size, hidden_size),
+            nn.ReLU(),
+            nn.Linear(hidden_size, 1)
+        )
+
+    def forward(self, state):
+        reward = self.reward_model(state).squeeze(dim=1)
+        return reward
+
+
 class TransitionModel(nn.Module):
     def __init__(self, state_size, hidden_size, action_size, history_size):
         super().__init__()
@@ -194,8 +210,7 @@ class TransitionModel(nn.Module):
         prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
         return prior
 
-    def stack_states(self, states, dim=0):
-
+    def stack_states(self, states, dim=0):
         s = dict(
             mean = torch.stack([state['mean'] for state in states], dim=dim),
             std = torch.stack([state['std'] for state in states], dim=dim),
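
A minimal usage sketch of the newly added RewardModel follows. The class body is copied from the lines added in this diff; the state_size, hidden_size, and batch_size values and the random latent_states tensor are hypothetical stand-ins for whatever the encoder/transition model actually produces in this repo.

import torch
import torch.nn as nn


class RewardModel(nn.Module):
    """Predicts a scalar reward from a latent state (same body as the diff above)."""
    def __init__(self, state_size, hidden_size):
        super().__init__()
        self.reward_model = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, state):
        # (batch, 1) -> (batch,): one predicted reward per latent state
        reward = self.reward_model(state).squeeze(dim=1)
        return reward


# Hypothetical sizes, for illustration only.
state_size, hidden_size, batch_size = 30, 256, 16
reward_model = RewardModel(state_size, hidden_size)
latent_states = torch.randn(batch_size, state_size)  # stand-in for latent states from the transition model
predicted_rewards = reward_model(latent_states)
print(predicted_rewards.shape)  # torch.Size([16])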