Adding Files

This commit is contained in:
Vedant Dave 2023-05-29 13:37:22 +02:00
parent 061057fa11
commit 96acac90da
3 changed files with 25 additions and 9 deletions

View File

@ -68,7 +68,7 @@ class DynamicBottleneck(object):
# contrastive projection # contrastive projection
z_a = self.projection_head(rec_vec) # (None, 128) z_a = self.projection_head(rec_vec) # (None, 128)
z_pos = tf.stop_gradient(self.projection_head_momentum(self.next_features)) # (None, 128) z_pos = tf.stop_gradient(self.projection_head_momentum(self.next_features)) # (None, 128)
assert z_a.get_shape().as_list()[-1] == 128 and len(z_a.get_shape().as_list()) == 2 assert z_a.get_shape().as_list()[-1] == 128 and len(z_a.get_shape().as_list()) == 2
# contrastive loss # contrastive loss
logits = self.contrastive_head([z_a, z_pos]) # (batch_size, batch_size) logits = self.contrastive_head([z_a, z_pos]) # (batch_size, batch_size)
@ -129,12 +129,11 @@ class DynamicBottleneck(object):
def get_features(self, x, momentum=False): # x.shape=(None,None,84,84,4) def get_features(self, x, momentum=False): # x.shape=(None,None,84,84,4)
x_has_timesteps = (x.get_shape().ndims == 5) # True x_has_timesteps = (x.get_shape().ndims == 5) # True
if x_has_timesteps: if x_has_timesteps:
sh = tf.shape(x) sh = tf.shape(x)
x = flatten_two_dims(x) # (None,84,84,4) x = flatten_two_dims(x) # (None,84,84,4)
if self.aug: if self.aug:
print(x.get_shape().as_list())
x = tf.image.random_crop(x, size=[128*16, 80, 80, 4]) # (None,80,80,4) x = tf.image.random_crop(x, size=[128*16, 80, 80, 4]) # (None,80,80,4)
x = tf.pad(x, [[0, 0], [4, 4], [4, 4], [0, 0]], "SYMMETRIC") # (None,88,88,4) x = tf.pad(x, [[0, 0], [4, 4], [4, 4], [0, 0]], "SYMMETRIC") # (None,88,88,4)
x = tf.image.random_crop(x, size=[128*16, 84, 84, 4]) # (None,84,84,4) x = tf.image.random_crop(x, size=[128*16, 84, 84, 4]) # (None,84,84,4)
@ -158,6 +157,7 @@ class DynamicBottleneck(object):
n_chunks = 8 n_chunks = 8
n = ob.shape[0] n = ob.shape[0]
chunk_size = n // n_chunks chunk_size = n // n_chunks
assert n % n_chunks == 0 assert n % n_chunks == 0
sli = lambda i: slice(i * chunk_size, (i + 1) * chunk_size) sli = lambda i: slice(i * chunk_size, (i + 1) * chunk_size)

View File

@ -10,7 +10,7 @@ class Rollout(object):
def __init__(self, ob_space, ac_space, nenvs, nsteps_per_seg, nsegs_per_env, nlumps, envs, policy, def __init__(self, ob_space, ac_space, nenvs, nsteps_per_seg, nsegs_per_env, nlumps, envs, policy,
int_rew_coeff, ext_rew_coeff, record_rollouts, dynamic_bottleneck): #, noisy_box, noisy_p): int_rew_coeff, ext_rew_coeff, record_rollouts, dynamic_bottleneck): #, noisy_box, noisy_p):
# int_rew_coeff=1.0, ext_rew_coeff=0.0, record_rollouts=True # int_rew_coeff=1.0, ext_rew_coeff=0.0, record_rollouts=True
self.nenvs = nenvs # 128 self.nenvs = nenvs # 128/64
self.nsteps_per_seg = nsteps_per_seg # 128 self.nsteps_per_seg = nsteps_per_seg # 128
self.nsegs_per_env = nsegs_per_env # 1 self.nsegs_per_env = nsegs_per_env # 1
self.nsteps = self.nsteps_per_seg * self.nsegs_per_env # 128 self.nsteps = self.nsteps_per_seg * self.nsegs_per_env # 128
@ -102,6 +102,22 @@ class Rollout(object):
if self.recorder is not None: if self.recorder is not None:
self.recorder.record(timestep=self.step_count, lump=l, acs=acs, infos=infos, int_rew=self.int_rew[sli], self.recorder.record(timestep=self.step_count, lump=l, acs=acs, infos=infos, int_rew=self.int_rew[sli],
ext_rew=prevrews, news=news) ext_rew=prevrews, news=news)
import matplotlib.pyplot as plt
from PIL import Image
x = self.buf_obs[sli, t][5][:,:,0]
img = Image.fromarray(x)
img.save('image1.png')
x = self.buf_obs[sli, t][5][:,:,1]
img = Image.fromarray(x)
img.save('image2.png')
x = self.buf_obs[sli, t][5][:,:,2]
img = Image.fromarray(x)
img.save('image3.png')
x = self.buf_obs[sli, t][5][:,:,3]
img = Image.fromarray(x)
img.save('image4.png')
self.step_count += 1 self.step_count += 1
if s == self.nsteps_per_seg - 1: # nsteps_per_seg=128 if s == self.nsteps_per_seg - 1: # nsteps_per_seg=128

10
run.py
View File

@ -14,7 +14,8 @@ from baselines.bench import Monitor
from baselines.common.atari_wrappers import NoopResetEnv, FrameStack from baselines.common.atari_wrappers import NoopResetEnv, FrameStack
from mpi4py import MPI from mpi4py import MPI
from dynamic_bottleneck import DynamicBottleneck
from dynamic_bottleneck import DynamicBottleneck
from cnn_policy import CnnPolicy from cnn_policy import CnnPolicy
from cppo_agent import PpoOptimizer from cppo_agent import PpoOptimizer
from utils import random_agent_ob_mean_std from utils import random_agent_ob_mean_std
@ -26,8 +27,8 @@ import json
getsess = tf.get_default_session getsess = tf.get_default_session
def start_experiment(**args): def start_experiment(**args):
make_env = partial(make_env_all_params, add_monitor=True, args=args) make_env = partial(make_env_all_params, add_monitor=True, args=args)
trainer = Trainer(make_env=make_env, trainer = Trainer(make_env=make_env,
@ -39,7 +40,6 @@ def start_experiment(**args):
print("results will be saved to ", logdir) print("results will be saved to ", logdir)
trainer.train(saver, logger_dir) trainer.train(saver, logger_dir)
class Trainer(object): class Trainer(object):
def __init__(self, make_env, hps, num_timesteps, envs_per_process): def __init__(self, make_env, hps, num_timesteps, envs_per_process):
self.make_env = make_env self.make_env = make_env
@ -118,10 +118,10 @@ class Trainer(object):
previous_saved_tcount += 1 previous_saved_tcount += 1
save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_"+str(previous_saved_tcount)+".ckpt")) save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_"+str(previous_saved_tcount)+".ckpt"))
print("Periodically model saved in path:", save_path) print("Periodically model saved in path:", save_path)
if self.agent.rollout.stats['tcount'] > self.num_timesteps: if self.agent.rollout.stats['tcount'] %10000: #self.agent.rollout.stats['tcount'] > self.num_timesteps:
save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_last.ckpt")) save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_last.ckpt"))
print("Model saved in path:", save_path) print("Model saved in path:", save_path)
break #break
self.agent.stop_interaction() self.agent.stop_interaction()