Adding Files
This commit is contained in:
parent
061057fa11
commit
96acac90da
@ -68,7 +68,7 @@ class DynamicBottleneck(object):
|
||||
# contrastive projection
|
||||
z_a = self.projection_head(rec_vec) # (None, 128)
|
||||
z_pos = tf.stop_gradient(self.projection_head_momentum(self.next_features)) # (None, 128)
|
||||
assert z_a.get_shape().as_list()[-1] == 128 and len(z_a.get_shape().as_list()) == 2
|
||||
assert z_a.get_shape().as_list()[-1] == 128 and len(z_a.get_shape().as_list()) == 2
|
||||
|
||||
# contrastive loss
|
||||
logits = self.contrastive_head([z_a, z_pos]) # (batch_size, batch_size)
|
||||
@ -129,12 +129,11 @@ class DynamicBottleneck(object):
|
||||
|
||||
def get_features(self, x, momentum=False): # x.shape=(None,None,84,84,4)
|
||||
x_has_timesteps = (x.get_shape().ndims == 5) # True
|
||||
if x_has_timesteps:
|
||||
if x_has_timesteps:
|
||||
sh = tf.shape(x)
|
||||
x = flatten_two_dims(x) # (None,84,84,4)
|
||||
|
||||
if self.aug:
|
||||
print(x.get_shape().as_list())
|
||||
x = tf.image.random_crop(x, size=[128*16, 80, 80, 4]) # (None,80,80,4)
|
||||
x = tf.pad(x, [[0, 0], [4, 4], [4, 4], [0, 0]], "SYMMETRIC") # (None,88,88,4)
|
||||
x = tf.image.random_crop(x, size=[128*16, 84, 84, 4]) # (None,84,84,4)
|
||||
@ -158,6 +157,7 @@ class DynamicBottleneck(object):
|
||||
n_chunks = 8
|
||||
n = ob.shape[0]
|
||||
chunk_size = n // n_chunks
|
||||
|
||||
assert n % n_chunks == 0
|
||||
sli = lambda i: slice(i * chunk_size, (i + 1) * chunk_size)
|
||||
|
||||
|
18
rollouts.py
18
rollouts.py
@ -10,7 +10,7 @@ class Rollout(object):
|
||||
def __init__(self, ob_space, ac_space, nenvs, nsteps_per_seg, nsegs_per_env, nlumps, envs, policy,
|
||||
int_rew_coeff, ext_rew_coeff, record_rollouts, dynamic_bottleneck): #, noisy_box, noisy_p):
|
||||
# int_rew_coeff=1.0, ext_rew_coeff=0.0, record_rollouts=True
|
||||
self.nenvs = nenvs # 128
|
||||
self.nenvs = nenvs # 128/64
|
||||
self.nsteps_per_seg = nsteps_per_seg # 128
|
||||
self.nsegs_per_env = nsegs_per_env # 1
|
||||
self.nsteps = self.nsteps_per_seg * self.nsegs_per_env # 128
|
||||
@ -102,6 +102,22 @@ class Rollout(object):
|
||||
if self.recorder is not None:
|
||||
self.recorder.record(timestep=self.step_count, lump=l, acs=acs, infos=infos, int_rew=self.int_rew[sli],
|
||||
ext_rew=prevrews, news=news)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
x = self.buf_obs[sli, t][5][:,:,0]
|
||||
img = Image.fromarray(x)
|
||||
img.save('image1.png')
|
||||
x = self.buf_obs[sli, t][5][:,:,1]
|
||||
img = Image.fromarray(x)
|
||||
img.save('image2.png')
|
||||
x = self.buf_obs[sli, t][5][:,:,2]
|
||||
img = Image.fromarray(x)
|
||||
img.save('image3.png')
|
||||
x = self.buf_obs[sli, t][5][:,:,3]
|
||||
img = Image.fromarray(x)
|
||||
img.save('image4.png')
|
||||
|
||||
|
||||
self.step_count += 1
|
||||
if s == self.nsteps_per_seg - 1: # nsteps_per_seg=128
|
||||
|
10
run.py
10
run.py
@ -14,7 +14,8 @@ from baselines.bench import Monitor
|
||||
from baselines.common.atari_wrappers import NoopResetEnv, FrameStack
|
||||
from mpi4py import MPI
|
||||
|
||||
from dynamic_bottleneck import DynamicBottleneck
|
||||
|
||||
from dynamic_bottleneck import DynamicBottleneck
|
||||
from cnn_policy import CnnPolicy
|
||||
from cppo_agent import PpoOptimizer
|
||||
from utils import random_agent_ob_mean_std
|
||||
@ -26,8 +27,8 @@ import json
|
||||
|
||||
getsess = tf.get_default_session
|
||||
|
||||
|
||||
def start_experiment(**args):
|
||||
|
||||
make_env = partial(make_env_all_params, add_monitor=True, args=args)
|
||||
|
||||
trainer = Trainer(make_env=make_env,
|
||||
@ -39,7 +40,6 @@ def start_experiment(**args):
|
||||
print("results will be saved to ", logdir)
|
||||
trainer.train(saver, logger_dir)
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
def __init__(self, make_env, hps, num_timesteps, envs_per_process):
|
||||
self.make_env = make_env
|
||||
@ -118,10 +118,10 @@ class Trainer(object):
|
||||
previous_saved_tcount += 1
|
||||
save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_"+str(previous_saved_tcount)+".ckpt"))
|
||||
print("Periodically model saved in path:", save_path)
|
||||
if self.agent.rollout.stats['tcount'] > self.num_timesteps:
|
||||
if self.agent.rollout.stats['tcount'] %10000: #self.agent.rollout.stats['tcount'] > self.num_timesteps:
|
||||
save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_last.ckpt"))
|
||||
print("Model saved in path:", save_path)
|
||||
break
|
||||
#break
|
||||
|
||||
self.agent.stop_interaction()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user