Compare commits
No commits in common. "8fd56ba94ded48373ce8af0c6b4b243c16642111" and "c4283ced6fa9699ac94fe598a6dac935efc0d7f4" have entirely different histories.
8fd56ba94d...c4283ced6f
@@ -93,7 +93,7 @@ class ObservationDecoder(nn.Module):
         return out_dist


-class Actor(nn.Module):
+class ActionDecoder(nn.Module):
     def __init__(self, state_size, hidden_size, action_size, num_layers=5):
         super().__init__()
         self.state_size = state_size
@@ -151,24 +151,8 @@ class ValueModel(nn.Module):
         value = self.value_model(state)
         value_dist = torch.distributions.independent.Independent(torch.distributions.Normal(value, 1), 1)
         return value_dist


-class RewardModel(nn.Module):
-    def __init__(self, state_size, hidden_size):
-        super().__init__()
-        self.reward_model = nn.Sequential(
-            nn.Linear(state_size, hidden_size),
-            nn.ReLU(),
-            nn.Linear(hidden_size, hidden_size),
-            nn.ReLU(),
-            nn.Linear(hidden_size, 1)
-        )
-
-    def forward(self, state):
-        reward = self.reward_model(state).squeeze(dim=1)
-        return reward
-
-
 class TransitionModel(nn.Module):
     def __init__(self, state_size, hidden_size, action_size, history_size):
         super().__init__()
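For reference on the `value_dist` construction that stays in this hunk: wrapping a `Normal` in `Independent(..., 1)` folds the last tensor dimension into the event, so `log_prob` returns one value per batch element. A minimal standalone sketch of that behaviour (the shapes are illustrative assumptions, not taken from the repository):

import torch

# Illustrative only: a batch of 4 predicted values with a trailing dim of 1,
# mirroring the value_dist line shown in the context above.
value = torch.zeros(4, 1)
value_dist = torch.distributions.independent.Independent(
    torch.distributions.Normal(value, 1), 1
)

target = torch.ones(4, 1)
# log_prob sums over the reinterpreted last dimension: one log-likelihood per batch item.
print(value_dist.log_prob(target).shape)  # torch.Size([4])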
@@ -210,7 +194,8 @@ class TransitionModel(nn.Module):
         prior = {"mean": state_prior_mean, "std": state_prior_std, "sample": sample_state_prior, "history": history, "distribution": state_prior_dist}
         return prior

     def stack_states(self, states, dim=0):
+
         s = dict(
             mean = torch.stack([state['mean'] for state in states], dim=dim),
             std = torch.stack([state['std'] for state in states], dim=dim),
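The `stack_states` helper visible in this hunk stacks per-step state dictionaries field by field with `torch.stack`. A small sketch of the same pattern, limited to the two fields shown in the diff (`mean` and `std`) and using made-up shapes:

import torch

# Two fake per-step states, each with mean/std of shape (batch=3, state_size=5).
states = [
    {"mean": torch.zeros(3, 5), "std": torch.ones(3, 5)},
    {"mean": torch.zeros(3, 5), "std": torch.ones(3, 5)},
]

stacked = dict(
    mean=torch.stack([state["mean"] for state in states], dim=0),
    std=torch.stack([state["std"] for state in states], dim=0),
)
print(stacked["mean"].shape)  # torch.Size([2, 3, 5]): time first, then batch, then state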
DPI/train.py (25 changes)
@@ -11,7 +11,7 @@ import tqdm
 import wandb
 import utils
 from utils import ReplayBuffer, make_env, save_image
-from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample, Actor, ValueModel, RewardModel
+from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
 from logger import Logger
 from video import VideoRecorder
 from dmc2gym.wrappers import set_global_var
@@ -175,27 +175,6 @@ class DPI:
             action_size=self.env.action_space.shape[0], # 6
             history_size=self.args.history_size, # 128
         )

-        self.action_model = Actor(
-            state_size=self.args.state_size, # 128
-            hidden_size=self.args.hidden_size, # 256,
-            action_size=self.env.action_space.shape[0], # 6
-        )
-
-        self.value_model = ValueModel(
-            state_size=self.args.state_size, # 128
-            hidden_size=self.args.hidden_size, # 256
-        )
-
-        self.target_value_model = ValueModel(
-            state_size=self.args.state_size, # 128
-            hidden_size=self.args.hidden_size, # 256
-        )
-
-        self.reward_model = RewardModel(
-            state_size=self.args.state_size, # 128
-            hidden_size=self.args.hidden_size, # 256
-        )
-
         # model parameters
         self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \
@@ -303,7 +282,7 @@ class DPI:

            imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
            imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"], self.action, self.history, imagine_horizon)
-
+            print(imagine_horizon)
            #exit()

            #print(total_ub_loss, total_encoder_loss)
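A quick illustration of the horizon clipping in the context above: `np.minimum(self.args.imagine_horizon, self.args.episode_length - 1 - i)` shortens the imagined rollout as the episode end approaches. The numbers below are invented for the example:

import numpy as np

imagine_horizon = 15   # assumed value, for illustration only
episode_length = 20    # assumed value, for illustration only

for i in (0, 10, 18):
    # Near the end of the episode the remaining steps, not the fixed horizon, bound the rollout.
    print(i, np.minimum(imagine_horizon, episode_length - 1 - i))
# prints: 0 15, then 10 9, then 18 1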