Compare commits

..

2 Commits

2 changed files with 43 additions and 7 deletions

View File

@ -93,7 +93,7 @@ class ObservationDecoder(nn.Module):
return out_dist return out_dist
class ActionDecoder(nn.Module): class Actor(nn.Module):
def __init__(self, state_size, hidden_size, action_size, num_layers=5): def __init__(self, state_size, hidden_size, action_size, num_layers=5):
super().__init__() super().__init__()
self.state_size = state_size self.state_size = state_size
@ -151,8 +151,24 @@ class ValueModel(nn.Module):
value = self.value_model(state) value = self.value_model(state)
value_dist = torch.distributions.independent.Independent(torch.distributions.Normal(value, 1), 1) value_dist = torch.distributions.independent.Independent(torch.distributions.Normal(value, 1), 1)
return value_dist return value_dist
class RewardModel(nn.Module):
def __init__(self, state_size, hidden_size):
super().__init__()
self.reward_model = nn.Sequential(
nn.Linear(state_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 1)
)
def forward(self, state):
reward = self.reward_model(state).squeeze(dim=1)
return reward
class TransitionModel(nn.Module): class TransitionModel(nn.Module):
def __init__(self, state_size, hidden_size, action_size, history_size): def __init__(self, state_size, hidden_size, action_size, history_size):
super().__init__() super().__init__()
@ -194,8 +210,7 @@ class TransitionModel(nn.Module):
prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist} prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
return prior return prior
def stack_states(self, states, dim=0): def stack_states(self, states, dim=0):
s = dict( s = dict(
mean = torch.stack([state['mean'] for state in states], dim=dim), mean = torch.stack([state['mean'] for state in states], dim=dim),
std = torch.stack([state['std'] for state in states], dim=dim), std = torch.stack([state['std'] for state in states], dim=dim),

View File

@ -11,7 +11,7 @@ import tqdm
import wandb import wandb
import utils import utils
from utils import ReplayBuffer, make_env, save_image from utils import ReplayBuffer, make_env, save_image
from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample, Actor, ValueModel, RewardModel
from logger import Logger from logger import Logger
from video import VideoRecorder from video import VideoRecorder
from dmc2gym.wrappers import set_global_var from dmc2gym.wrappers import set_global_var
@ -175,6 +175,27 @@ class DPI:
action_size=self.env.action_space.shape[0], # 6 action_size=self.env.action_space.shape[0], # 6
history_size=self.args.history_size, # 128 history_size=self.args.history_size, # 128
) )
self.action_model = Actor(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256,
action_size=self.env.action_space.shape[0], # 6
)
self.value_model = ValueModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
)
self.target_value_model = ValueModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
)
self.reward_model = RewardModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
)
# model parameters # model parameters
self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \ self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \
@ -282,7 +303,7 @@ class DPI:
imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i) imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"], self.action, self.history, imagine_horizon) imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"], self.action, self.history, imagine_horizon)
print(imagine_horizon)
#exit() #exit()
#print(total_ub_loss, total_encoder_loss) #print(total_ub_loss, total_encoder_loss)