Adding Reward, Value and Target Value models

Vedant Dave 2023-04-10 13:18:08 +02:00
parent c4283ced6f
commit 47090449d1


@@ -11,7 +11,7 @@ import tqdm
import wandb
import utils
from utils import ReplayBuffer, make_env, save_image
-from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample
+from models import ObservationEncoder, ObservationDecoder, TransitionModel, CLUBSample, Actor, ValueModel, RewardModel
from logger import Logger
from video import VideoRecorder
from dmc2gym.wrappers import set_global_var
@@ -176,6 +176,27 @@ class DPI:
history_size=self.args.history_size, # 128
)
self.action_model = Actor(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256,
action_size=self.env.action_space.shape[0], # 6
)
self.value_model = ValueModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
)
self.target_value_model = ValueModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
)
self.reward_model = RewardModel(
state_size=self.args.state_size, # 128
hidden_size=self.args.hidden_size, # 256
)
# model parameters
self.model_parameters = list(self.obs_encoder.parameters()) + list(self.obs_encoder_momentum.parameters()) + \
    list(self.obs_decoder.parameters()) + list(self.transition_model.parameters())
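
The hunk above instantiates four new heads on the DPI agent: an Actor over the environment's 6-dimensional action space, a ValueModel together with a separate target copy, and a RewardModel, all built from state_size=128 and hidden_size=256. Their class definitions live in models.py and are not shown in this diff, so the sketch below is only a plausible reading of those heads, assuming small ELU MLPs and a tanh-squashed Gaussian policy; it is not the repository's actual implementation.

# Hypothetical sketch of the heads imported from models.py; layer sizes,
# activations and the policy distribution are assumptions, not the repo's code.
import torch
import torch.nn as nn
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

class RewardModel(nn.Module):
    """Maps a latent state to a scalar reward estimate."""
    def __init__(self, state_size, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size, hidden_size), nn.ELU(),
            nn.Linear(hidden_size, hidden_size), nn.ELU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state):
        return self.net(state)

class ValueModel(nn.Module):
    """Maps a latent state to a state-value estimate; the target copy has the same shape."""
    def __init__(self, state_size, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size, hidden_size), nn.ELU(),
            nn.Linear(hidden_size, hidden_size), nn.ELU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, state):
        return self.net(state)

class Actor(nn.Module):
    """Maps a latent state to a tanh-squashed Gaussian over continuous actions."""
    def __init__(self, state_size, hidden_size, action_size):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(state_size, hidden_size), nn.ELU(),
            nn.Linear(hidden_size, hidden_size), nn.ELU(),
        )
        self.mean = nn.Linear(hidden_size, action_size)
        self.log_std = nn.Linear(hidden_size, action_size)

    def forward(self, state):
        h = self.trunk(state)
        std = torch.exp(self.log_std(h).clamp(-5, 2))
        base = Normal(self.mean(h), std)
        return TransformedDistribution(base, TanhTransform(cache_size=1))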
@@ -282,7 +303,7 @@ class DPI:
imagine_horizon = np.minimum(self.args.imagine_horizon, self.args.episode_length-1-i)
imagined_rollout = self.transition_model.imagine_rollout(self.current_states_dict["sample"], self.action, self.history, imagine_horizon)
print(imagine_horizon)
#exit()
#print(total_ub_loss, total_encoder_loss)
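
This commit only constructs the new heads and adds a debug print of the imagination horizon; their training updates are not part of the diff. For context, a target value network is usually kept as a slowly moving copy of the online value network, e.g. via the hypothetical helper below (the name soft_update and the rate tau are illustrative, not taken from this repository).

# Hypothetical usage sketch, not part of this commit: refresh the target value
# network as an exponential moving average of the online value network.
import torch

def soft_update(online: torch.nn.Module, target: torch.nn.Module, tau: float = 0.01) -> None:
    # target <- (1 - tau) * target + tau * online, applied parameter-wise
    with torch.no_grad():
        for p, p_t in zip(online.parameters(), target.parameters()):
            p_t.data.mul_(1.0 - tau).add_(tau * p.data)

# e.g. after each value-model gradient step:
# soft_update(self.value_model, self.target_value_model, tau=0.01)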