Compare commits


No commits in common. "8fd56ba94ded48373ce8af0c6b4b243c16642111" and "c4283ced6fa9699ac94fe598a6dac935efc0d7f4" have entirely different histories.

2 changed files with 7 additions and 43 deletions

@@ -93,7 +93,7 @@ class ObservationDecoder(nn.Module):
         return out_dist

-class Actor(nn.Module):
+class ActionDecoder(nn.Module):
     def __init__(self, state_size, hidden_size, action_size, num_layers=5):
         super().__init__()
         self.state_size = state_size
@@ -151,24 +151,8 @@ class ValueModel(nn.Module):
         value = self.value_model(state)
         value_dist = torch.distributions.independent.Independent(torch.distributions.Normal(value, 1), 1)
         return value_dist

-class RewardModel(nn.Module):
-    def __init__(self, state_size, hidden_size):
-        super().__init__()
-        self.reward_model = nn.Sequential(
-            nn.Linear(state_size, hidden_size),
-            nn.ReLU(),
-            nn.Linear(hidden_size, hidden_size),
-            nn.ReLU(),
-            nn.Linear(hidden_size, 1)
-        )
-
-    def forward(self, state):
-        reward = self.reward_model(state).squeeze(dim=1)
-        return reward
-
 class TransitionModel(nn.Module):
     def __init__(self, state_size, hidden_size, action_size, history_size):
         super().__init__()
@@ -210,7 +194,8 @@ class TransitionModel(nn.Module):
         prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
         return prior

     def stack_states(self, states, dim=0):
         s = dict(
             mean = torch.stack([state['mean'] for state in states], dim=dim),
             std = torch.stack([state['std'] for state in states], dim=dim),
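
For context on the ValueModel shown above: its forward pass wraps the predicted value in torch.distributions.Independent(Normal(value, 1), 1), i.e. a unit-variance Normal over the last dimension. A minimal sketch of how such a head is typically turned into a training loss (the tensors and the return target below are hypothetical placeholders, not taken from this repository):

import torch

# Hypothetical batch of 4 latent states; the value head predicts one value each.
value = torch.randn(4, 1)                        # stand-in for self.value_model(state)
value_dist = torch.distributions.Independent(
    torch.distributions.Normal(value, 1), 1)     # unit-variance Normal, last dim is the event dim

target = torch.randn(4, 1)                       # e.g. lambda-returns in a Dreamer-style setup
loss = -value_dist.log_prob(target).mean()       # equals 0.5 * MSE plus a constant for unit variance
print(loss)                                      # scalar loss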

@@ -11,7 +11,7 @@ import tqdm
 import wandb
 import utils
 from utils import ReplayBuffer, make_env, save_image
-from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample, Actor, ValueModel, RewardModel
+from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
 from logger import Logger
 from video import VideoRecorder
 from dmc2gym.wrappers import set_global_var
@@ -175,27 +175,6 @@ class DPI:
             action_size=self.env.action_space.shape[0],  # 6
             history_size=self.args.history_size,  # 128
         )
-        self.action_model = Actor(
-            state_size=self.args.state_size,  # 128
-            hidden_size=self.args.hidden_size,  # 256,
-            action_size=self.env.action_space.shape[0],  # 6
-        )
-
-        self.value_model = ValueModel(
-            state_size=self.args.state_size,  # 128
-            hidden_size=self.args.hidden_size,  # 256
-        )
-
-        self.target_value_model = ValueModel(
-            state_size=self.args.state_size,  # 128
-            hidden_size=self.args.hidden_size,  # 256
-        )
-
-        self.reward_model = RewardModel(
-            state_size=self.args.state_size,  # 128
-            hidden_size=self.args.hidden_size,  # 256
-        )

         # model parameters
         self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \
@@ -303,7 +282,7 @@ class DPI:
             imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
             imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"], self.action, self.history, imagine_horizon)
-            print(imagine_horizon)
             #exit()
             #print(total_ub_loss, total_encoder_loss)
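
The imagine_horizon line kept in the last hunk clamps the imagined rollout so it never runs past the end of the episode. A small illustration with assumed settings (the 15-step horizon and 1000-step episode length are placeholders, not values read from this repository's config):

import numpy as np

imagine_horizon = 15      # assumed args.imagine_horizon
episode_length = 1000     # assumed args.episode_length

for i in (0, 990, 998):
    horizon = np.minimum(imagine_horizon, episode_length - 1 - i)
    print(i, horizon)     # 0 -> 15, 990 -> 9, 998 -> 1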