diff --git a/DPI/train.py b/DPI/train.py
index db8c027..0fe08a2 100644
--- a/DPI/train.py
+++ b/DPI/train.py
@@ -11,7 +11,7 @@ import tqdm
 import wandb
 import utils
 from utils import ReplayBuffer, make_env, save_image
-from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
+from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample, Actor, ValueModel, RewardModel
 from logger import Logger
 from video import VideoRecorder
 from dmc2gym.wrappers import set_global_var
@@ -175,6 +175,27 @@ class DPI:
             action_size=self.env.action_space.shape[0],  # 6
             history_size=self.args.history_size,  # 128
         )
+
+        self.action_model = Actor(
+            state_size=self.args.state_size,  # 128
+            hidden_size=self.args.hidden_size,  # 256
+            action_size=self.env.action_space.shape[0],  # 6
+        )
+
+        self.value_model = ValueModel(
+            state_size=self.args.state_size,  # 128
+            hidden_size=self.args.hidden_size,  # 256
+        )
+
+        self.target_value_model = ValueModel(
+            state_size=self.args.state_size,  # 128
+            hidden_size=self.args.hidden_size,  # 256
+        )
+
+        self.reward_model = RewardModel(
+            state_size=self.args.state_size,  # 128
+            hidden_size=self.args.hidden_size,  # 256
+        )
 
         # model parameters
         self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \
@@ -282,7 +303,7 @@ class DPI:
             imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
             imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"],
                                                                      self.action, self.history, imagine_horizon)
-            print(imagine_horizon)
+            #exit()
             #print(total_ub_loss, total_encoder_loss)
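The Actor, ValueModel, and RewardModel classes imported above are defined in DPI/models.py, which is not part of this diff. Below is a minimal sketch of what they might look like; only the constructor keywords (state_size, hidden_size, action_size) are taken from the calls in train.py, while the layer widths, ELU activations, tanh squashing, and the _mlp helper are assumptions for illustration.

```python
# Hypothetical sketch of the classes imported from DPI/models.py (not shown in
# this diff). Only the constructor keywords match the calls in train.py; the
# layer layout, activations, and tanh squashing are assumptions.
import torch.nn as nn


def _mlp(in_size, hidden_size, out_size, out_act=None):
    # Two hidden layers, a common choice for value/reward heads over a latent state.
    layers = [nn.Linear(in_size, hidden_size), nn.ELU(),
              nn.Linear(hidden_size, hidden_size), nn.ELU(),
              nn.Linear(hidden_size, out_size)]
    if out_act is not None:
        layers.append(out_act)
    return nn.Sequential(*layers)


class Actor(nn.Module):
    """Maps a latent state to a tanh-squashed action."""
    def __init__(self, state_size, hidden_size, action_size):
        super().__init__()
        self.net = _mlp(state_size, hidden_size, action_size, nn.Tanh())

    def forward(self, state):
        return self.net(state)


class ValueModel(nn.Module):
    """Predicts a scalar state value from a latent state."""
    def __init__(self, state_size, hidden_size):
        super().__init__()
        self.net = _mlp(state_size, hidden_size, 1)

    def forward(self, state):
        return self.net(state)


class RewardModel(nn.Module):
    """Predicts a scalar reward from a latent state."""
    def __init__(self, state_size, hidden_size):
        super().__init__()
        self.net = _mlp(state_size, hidden_size, 1)

    def forward(self, state):
        return self.net(state)
```

Two ValueModel instances are constructed in the diff presumably because target_value_model serves as a target network whose parameters are periodically copied from value_model to stabilize value learning; that update step is not shown here.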