Adding action network

Vedant Dave 2023-03-31 18:00:07 +02:00
parent 13765c2f9e
commit 4e1ef89924


@@ -48,6 +48,7 @@ def parse_args():
parser.add_argument('--state_size', default=256, type=int)
parser.add_argument('--hidden_size', default=128, type=int)
parser.add_argument('--history_size', default=128, type=int)
parser.add_argument('--num-units', type=int, default=200, help='num hidden units for reward/value/discount models')
parser.add_argument('--load_encoder', default=None, type=str)
parser.add_argument('--imagination_horizon', default=15, type=str)
# eval
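
The new --num-units flag appears to set the hidden width of the reward/value/discount models mentioned in its help string. Those model classes are not part of this diff; the following is a minimal sketch of a head that could consume the flag, assuming a plain two-layer PyTorch MLP (DenseHead and its wiring are illustrative, not taken from the commit):

import torch
import torch.nn as nn

class DenseHead(nn.Module):
    # Hypothetical head for reward/value/discount prediction;
    # 'num_units' mirrors the new --num-units argument.
    def __init__(self, feature_size, num_units=200, out_size=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feature_size, num_units),
            nn.ELU(),
            nn.Linear(num_units, num_units),
            nn.ELU(),
            nn.Linear(num_units, out_size),
        )

    def forward(self, features):
        return self.net(features)

# e.g. reward_model = DenseHead(args.state_size, num_units=args.num_units)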
@@ -197,7 +198,7 @@ class DPI:
if args.save_video:
self.env.video.init(enabled=True)
self.env_clean.video.init(enabled=True)
for i in range(self.args.episode_length):
action = self.env.action_space.sample()
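
This collection loop still draws actions uniformly from the environment's action space. Once the action network from this commit is trained, it could supply actions after a warm-up phase; a hedged sketch of such a helper follows (select_action, init_steps, and action_model are illustrative assumptions, not shown in this diff):

import torch

def select_action(env, action_model, features, step, init_steps=1000):
    # Hypothetical helper: uniform exploration during warm-up,
    # action-network proposals afterwards.
    if action_model is None or step < init_steps:
        return env.action_space.sample()
    with torch.no_grad():
        return action_model(features).cpu().numpy()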
@@ -258,7 +259,7 @@ class DPI:
self.actions, self.history, i)
print(past_encoder_loss, past_latent_loss)
previous_information_loss = past_latent_loss
previous_encoder_loss = past_encoder_loss
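
The hunk above only carries the encoder and latent losses forward between steps; the action network named in the commit title is not visible in this excerpt. A minimal sketch of a tanh-squashed policy of the kind model-based agents commonly use, assuming it maps state/history features to a continuous action (ActionModel and its sizes are illustrative, not confirmed by the commit):

import torch
import torch.nn as nn

class ActionModel(nn.Module):
    # Hypothetical action network: features -> tanh-squashed continuous action.
    def __init__(self, feature_size, action_size, hidden_size=128):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(feature_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU(),
        )
        self.mean = nn.Linear(hidden_size, action_size)

    def forward(self, features):
        # Squash to [-1, 1], the usual range for dm_control-style action spaces.
        return torch.tanh(self.mean(self.trunk(features)))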