diff --git a/dynamic_bottleneck.py b/dynamic_bottleneck.py index 25313f9..9481755 100644 --- a/dynamic_bottleneck.py +++ b/dynamic_bottleneck.py @@ -68,7 +68,7 @@ class DynamicBottleneck(object): # contrastive projection z_a = self.projection_head(rec_vec) # (None, 128) z_pos = tf.stop_gradient(self.projection_head_momentum(self.next_features)) # (None, 128) - assert z_a.get_shape().as_list()[-1] == 128 and len(z_a.get_shape().as_list()) == 2 + assert z_a.get_shape().as_list()[-1] == 128 and len(z_a.get_shape().as_list()) == 2 # contrastive loss logits = self.contrastive_head([z_a, z_pos]) # (batch_size, batch_size) @@ -129,12 +129,11 @@ class DynamicBottleneck(object): def get_features(self, x, momentum=False): # x.shape=(None,None,84,84,4) x_has_timesteps = (x.get_shape().ndims == 5) # True - if x_has_timesteps: + if x_has_timesteps: sh = tf.shape(x) x = flatten_two_dims(x) # (None,84,84,4) if self.aug: - print(x.get_shape().as_list()) x = tf.image.random_crop(x, size=[128*16, 80, 80, 4]) # (None,80,80,4) x = tf.pad(x, [[0, 0], [4, 4], [4, 4], [0, 0]], "SYMMETRIC") # (None,88,88,4) x = tf.image.random_crop(x, size=[128*16, 84, 84, 4]) # (None,84,84,4) @@ -158,6 +157,7 @@ class DynamicBottleneck(object): n_chunks = 8 n = ob.shape[0] chunk_size = n // n_chunks + assert n % n_chunks == 0 sli = lambda i: slice(i * chunk_size, (i + 1) * chunk_size) diff --git a/rollouts.py b/rollouts.py index 44dc6b5..bc776aa 100644 --- a/rollouts.py +++ b/rollouts.py @@ -10,7 +10,7 @@ class Rollout(object): def __init__(self, ob_space, ac_space, nenvs, nsteps_per_seg, nsegs_per_env, nlumps, envs, policy, int_rew_coeff, ext_rew_coeff, record_rollouts, dynamic_bottleneck): #, noisy_box, noisy_p): # int_rew_coeff=1.0, ext_rew_coeff=0.0, record_rollouts=True - self.nenvs = nenvs # 128 + self.nenvs = nenvs # 128/64 self.nsteps_per_seg = nsteps_per_seg # 128 self.nsegs_per_env = nsegs_per_env # 1 self.nsteps = self.nsteps_per_seg * self.nsegs_per_env # 128 @@ -102,6 +102,22 @@ class Rollout(object): if self.recorder is not None: self.recorder.record(timestep=self.step_count, lump=l, acs=acs, infos=infos, int_rew=self.int_rew[sli], ext_rew=prevrews, news=news) + + import matplotlib.pyplot as plt + from PIL import Image + x = self.buf_obs[sli, t][5][:,:,0] + img = Image.fromarray(x) + img.save('image1.png') + x = self.buf_obs[sli, t][5][:,:,1] + img = Image.fromarray(x) + img.save('image2.png') + x = self.buf_obs[sli, t][5][:,:,2] + img = Image.fromarray(x) + img.save('image3.png') + x = self.buf_obs[sli, t][5][:,:,3] + img = Image.fromarray(x) + img.save('image4.png') + self.step_count += 1 if s == self.nsteps_per_seg - 1: # nsteps_per_seg=128 diff --git a/run.py b/run.py index a6ee0ba..2e94f91 100644 --- a/run.py +++ b/run.py @@ -14,7 +14,8 @@ from baselines.bench import Monitor from baselines.common.atari_wrappers import NoopResetEnv, FrameStack from mpi4py import MPI -from dynamic_bottleneck import DynamicBottleneck + +from dynamic_bottleneck import DynamicBottleneck from cnn_policy import CnnPolicy from cppo_agent import PpoOptimizer from utils import random_agent_ob_mean_std @@ -26,8 +27,8 @@ import json getsess = tf.get_default_session - def start_experiment(**args): + make_env = partial(make_env_all_params, add_monitor=True, args=args) trainer = Trainer(make_env=make_env, @@ -39,7 +40,6 @@ def start_experiment(**args): print("results will be saved to ", logdir) trainer.train(saver, logger_dir) - class Trainer(object): def __init__(self, make_env, hps, num_timesteps, envs_per_process): self.make_env = make_env @@ -118,10 +118,10 @@ class Trainer(object): previous_saved_tcount += 1 save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_"+str(previous_saved_tcount)+".ckpt")) print("Periodically model saved in path:", save_path) - if self.agent.rollout.stats['tcount'] > self.num_timesteps: + if self.agent.rollout.stats['tcount'] %10000: #self.agent.rollout.stats['tcount'] > self.num_timesteps: save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_last.ckpt")) print("Model saved in path:", save_path) - break + #break self.agent.stop_interaction()