Adding Files
This commit is contained in:
parent
061057fa11
commit
96acac90da
dynamic_bottleneck.py
@@ -68,7 +68,7 @@ class DynamicBottleneck(object):
        # contrastive projection
        z_a = self.projection_head(rec_vec)  # (None, 128)
        z_pos = tf.stop_gradient(self.projection_head_momentum(self.next_features))  # (None, 128)
        assert z_a.get_shape().as_list()[-1] == 128 and len(z_a.get_shape().as_list()) == 2

        # contrastive loss
        logits = self.contrastive_head([z_a, z_pos])  # (batch_size, batch_size)
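(Not part of the diff.) For context, `contrastive_head` is presumably computing CURL/SimCLR-style similarity logits between the online projection z_a and the momentum projection z_pos. A minimal sketch of that computation, assuming cosine-similarity logits and an InfoNCE loss; the function name and the temperature value are placeholders, not the repository's actual implementation:

import tensorflow as tf

def contrastive_logits(z_a, z_pos, temperature=0.1):
    # L2-normalise both projections so the dot product is a cosine similarity.
    z_a = tf.math.l2_normalize(z_a, axis=-1)      # (batch_size, 128)
    z_pos = tf.math.l2_normalize(z_pos, axis=-1)  # (batch_size, 128)
    # Pairwise similarities; the diagonal entries are the positive pairs.
    logits = tf.matmul(z_a, z_pos, transpose_b=True) / temperature  # (batch_size, batch_size)
    # InfoNCE: classify each row's positive (the diagonal) among all columns.
    labels = tf.range(tf.shape(logits)[0])
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
    return logits, loss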
@@ -129,12 +129,11 @@ class DynamicBottleneck(object):

    def get_features(self, x, momentum=False):  # x.shape=(None,None,84,84,4)
        x_has_timesteps = (x.get_shape().ndims == 5)  # True
        if x_has_timesteps:
            sh = tf.shape(x)
            x = flatten_two_dims(x)  # (None,84,84,4)

        if self.aug:
-           print(x.get_shape().as_list())
            x = tf.image.random_crop(x, size=[128*16, 80, 80, 4])  # (None,80,80,4)
            x = tf.pad(x, [[0, 0], [4, 4], [4, 4], [0, 0]], "SYMMETRIC")  # (None,88,88,4)
            x = tf.image.random_crop(x, size=[128*16, 84, 84, 4])  # (None,84,84,4)
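(Not part of the diff.) The three augmentation lines above implement a random shift: crop 84 to 80, symmetric-pad back to 88, crop again to 84. A standalone sketch of the same transform, with the batch size left as a parameter instead of the hard-coded 128*16:

import tensorflow as tf

def random_shift(x, batch_size):
    # x: (batch_size, 84, 84, 4) stacked frames
    x = tf.image.random_crop(x, size=[batch_size, 80, 80, 4])      # (batch_size, 80, 80, 4)
    x = tf.pad(x, [[0, 0], [4, 4], [4, 4], [0, 0]], "SYMMETRIC")   # (batch_size, 88, 88, 4)
    x = tf.image.random_crop(x, size=[batch_size, 84, 84, 4])      # (batch_size, 84, 84, 4)
    return x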
@@ -158,6 +157,7 @@ class DynamicBottleneck(object):
        n_chunks = 8
        n = ob.shape[0]
        chunk_size = n // n_chunks

        assert n % n_chunks == 0
        sli = lambda i: slice(i * chunk_size, (i + 1) * chunk_size)

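(Not part of the diff.) n_chunks, chunk_size, and the sli helper above suggest the usual pattern of pushing a large observation batch through the graph in equal-sized chunks to bound memory; a generic sketch of that pattern, with function and argument names of my own rather than the repository's:

import numpy as np

def run_in_chunks(sess, output_tensor, input_ph, ob, n_chunks=8):
    # Split the batch into n_chunks equal slices, run each through the graph,
    # then concatenate the per-chunk results back into one array.
    n = ob.shape[0]
    assert n % n_chunks == 0
    chunk_size = n // n_chunks
    sli = lambda i: slice(i * chunk_size, (i + 1) * chunk_size)
    outs = [sess.run(output_tensor, feed_dict={input_ph: ob[sli(i)]})
            for i in range(n_chunks)]
    return np.concatenate(outs, axis=0)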
rollouts.py  (18 changed lines)
@@ -10,7 +10,7 @@ class Rollout(object):
    def __init__(self, ob_space, ac_space, nenvs, nsteps_per_seg, nsegs_per_env, nlumps, envs, policy,
                 int_rew_coeff, ext_rew_coeff, record_rollouts, dynamic_bottleneck):  #, noisy_box, noisy_p):
        # int_rew_coeff=1.0, ext_rew_coeff=0.0, record_rollouts=True
-       self.nenvs = nenvs  # 128
+       self.nenvs = nenvs  # 128/64
        self.nsteps_per_seg = nsteps_per_seg  # 128
        self.nsegs_per_env = nsegs_per_env  # 1
        self.nsteps = self.nsteps_per_seg * self.nsegs_per_env  # 128
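(Not part of the diff.) Given the shape comments above (nenvs=128, nsteps=128), the rollout buffers such as buf_obs used later in this file are presumably allocated as (nenvs, nsteps, *ob_space.shape) arrays; a sketch of that allocation, where the exact attribute names and dtypes are assumptions rather than the repository's code, and np is numpy:

# Assumed buffer allocation (a sketch only), shapes inferred from the comments above:
self.buf_obs = np.empty((nenvs, self.nsteps) + ob_space.shape, ob_space.dtype)   # (128, 128, 84, 84, 4)
self.buf_rews = np.zeros((nenvs, self.nsteps), np.float32)
self.buf_news = np.zeros((nenvs, self.nsteps), np.float32)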
@@ -102,6 +102,22 @@ class Rollout(object):
            if self.recorder is not None:
                self.recorder.record(timestep=self.step_count, lump=l, acs=acs, infos=infos, int_rew=self.int_rew[sli],
                                     ext_rew=prevrews, news=news)

+           import matplotlib.pyplot as plt
+           from PIL import Image
+           x = self.buf_obs[sli, t][5][:,:,0]
+           img = Image.fromarray(x)
+           img.save('image1.png')
+           x = self.buf_obs[sli, t][5][:,:,1]
+           img = Image.fromarray(x)
+           img.save('image2.png')
+           x = self.buf_obs[sli, t][5][:,:,2]
+           img = Image.fromarray(x)
+           img.save('image3.png')
+           x = self.buf_obs[sli, t][5][:,:,3]
+           img = Image.fromarray(x)
+           img.save('image4.png')
+
+
        self.step_count += 1
        if s == self.nsteps_per_seg - 1:  # nsteps_per_seg=128
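(Not part of the diff.) The sixteen added lines above are debug output: they take the buffered observation of environment index 5 at timestep t, an (84, 84, 4) stack of four grayscale frames, and save each channel as a PNG. An equivalent loop form, offered only as a sketch (the matplotlib import in the hunk is unused):

from PIL import Image

obs = self.buf_obs[sli, t][5]                      # (84, 84, 4) frame stack, uint8
for c in range(obs.shape[-1]):
    Image.fromarray(obs[:, :, c]).save('image%d.png' % (c + 1))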
run.py  (10 changed lines)
@@ -14,7 +14,8 @@ from baselines.bench import Monitor
from baselines.common.atari_wrappers import NoopResetEnv, FrameStack
from mpi4py import MPI

-from dynamic_bottleneck import DynamicBottleneck
+from dynamic_bottleneck import DynamicBottleneck
from cnn_policy import CnnPolicy
from cppo_agent import PpoOptimizer
from utils import random_agent_ob_mean_std
@@ -26,8 +27,8 @@ import json

getsess = tf.get_default_session


def start_experiment(**args):
    make_env = partial(make_env_all_params, add_monitor=True, args=args)

    trainer = Trainer(make_env=make_env,
@@ -39,7 +40,6 @@ def start_experiment(**args):
    print("results will be saved to ", logdir)
    trainer.train(saver, logger_dir)


class Trainer(object):
    def __init__(self, make_env, hps, num_timesteps, envs_per_process):
        self.make_env = make_env
@@ -118,10 +118,10 @@ class Trainer(object):
                previous_saved_tcount += 1
                save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_"+str(previous_saved_tcount)+".ckpt"))
                print("Periodically model saved in path:", save_path)
-           if self.agent.rollout.stats['tcount'] > self.num_timesteps:
+           if self.agent.rollout.stats['tcount'] %10000: #self.agent.rollout.stats['tcount'] > self.num_timesteps:
                save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_last.ckpt"))
                print("Model saved in path:", save_path)
-               break
+               #break

        self.agent.stop_interaction()

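(Not part of the diff.) Note what the two changed lines do: tcount % 10000 is truthy for every count that is not an exact multiple of 10000, so model_last.ckpt is now written on almost every pass, and with break commented out the loop no longer stops at num_timesteps. If the intent was a periodic "last" checkpoint with the original stopping rule kept, one way to express that is sketched below; last_saved_tcount is an assumed counter, not an existing attribute:

tcount = self.agent.rollout.stats['tcount']
if tcount - last_saved_tcount >= 10000:        # save at most once per 10000 timesteps
    last_saved_tcount = tcount
    save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_last.ckpt"))
    print("Model saved in path:", save_path)
if tcount > self.num_timesteps:                # original stopping condition
    break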