Compare commits

..

2 Commits

2 changed files with 43 additions and 7 deletions

View File

@ -93,7 +93,7 @@ class ObservationDecoder(nn.Module):
return out_dist
class ActionDecoder(nn.Module):
class Actor(nn.Module):
def __init__(self, state_size, hidden_size, action_size, num_layers=5):
super().__init__()
self.state_size = state_size
@ -153,6 +153,22 @@ class ValueModel(nn.Module):
return value_dist
class RewardModel(nn.Module):
def __init__(self, state_size, hidden_size):
super().__init__()
self.reward_model = nn.Sequential(
nn.Linear(state_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 1)
)
def forward(self, state):
reward = self.reward_model(state).squeeze(dim=1)
return reward
class TransitionModel(nn.Module):
def __init__(self, state_size, hidden_size, action_size, history_size):
super().__init__()
@ -195,7 +211,6 @@ class TransitionModel(nn.Module):
return prior
def stack_states(self, states, dim=0):
s = dict(
mean = torch.stack([state['mean'] for state in states], dim=dim),
std = torch.stack([state['std'] for state in states], dim=dim),

View File

@ -11,7 +11,7 @@ import tqdm
import wandb
import utils
from utils import ReplayBuffer, make_env, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample, Actor, ValueModel, RewardModel
from logger import Logger
from video import VideoRecorder
from dmc2gym.wrappers import set_global_var
@ -176,6 +176,27 @@ class DPI:
history_size=self.args.history_size, # 128
)
self.action_model = Actor(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256,
action_size=self.env.action_space.shape[0], # 6
)
self.value_model = ValueModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
)
self.target_value_model = ValueModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
)
self.reward_model = RewardModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
)
# model parameters
self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \
list(self.obs_decoder.parameters()) + list(self.transition_model.parameters())
@ -282,7 +303,7 @@ class DPI:
imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"], self.action, self.history, imagine_horizon)
print(imagine_horizon)
#exit()
#print(total_ub_loss, total_encoder_loss)