DB/run.py

288 lines
12 KiB
Python
Raw Normal View History

2021-10-09 00:33:47 +00:00
#!/usr/bin/env python
try:
from OpenGL import GLU
except:
print("no OpenGL.GLU")
import functools
import os.path as osp
from functools import partial
import os
import gym
import tensorflow as tf
from baselines import logger
from baselines.bench import Monitor
from baselines.common.atari_wrappers import NoopResetEnv, FrameStack
from mpi4py import MPI
2023-05-29 11:37:22 +00:00
from dynamic_bottleneck import DynamicBottleneck
2021-10-09 00:33:47 +00:00
from cnn_policy import CnnPolicy
from cppo_agent import PpoOptimizer
from utils import random_agent_ob_mean_std
from wrappers import MontezumaInfoWrapper, make_mario_env, make_robo_pong, make_robo_hockey, \
make_multi_pong, AddRandomStateToInfo, MaxAndSkipEnv, ProcessFrame84, ExtraTimeLimit, StickyActionEnv
import datetime
from wrappers import PixelNoiseWrapper, RandomBoxNoiseWrapper
import json
# Shorthand for TensorFlow's default-session accessor (tf.get_default_session);
# Trainer.train calls getsess() to run the Dynamic-Bottleneck init ops.
getsess = tf.get_default_session
def start_experiment(**args):
    """Build a Trainer from the parsed command-line options and run it.

    Args:
        **args: the full option dict; it is forwarded both as the Trainer's
            ``hps`` and to every environment factory.
    """
    # Each env the trainer builds gets a Monitor and sees the complete option dict.
    env_factory = partial(make_env_all_params, add_monitor=True, args=args)
    # The trainer (and therefore its TF variables) is created before
    # get_experiment_environment: that function builds tf.train.Saver() with no
    # var_list, which collects the variables that exist at construction time.
    trainer = Trainer(
        make_env=env_factory,
        num_timesteps=args['num_timesteps'],
        hps=args,
        envs_per_process=args['envs_per_process'],
    )
    logger_ctx, session_ctx, saver, logger_dir = get_experiment_environment(**args)
    with logger_ctx, session_ctx:
        print("results will be saved to ", logger.get_dir())
        trainer.train(saver, logger_dir)
class Trainer(object):
    """Wire the CNN policy, the Dynamic-Bottleneck module and the PPO optimizer together and train.

    Args:
        make_env: callable ``make_env(rank, add_monitor=...)`` returning one environment.
        hps: dict of hyper-parameters (start_experiment passes the parsed CLI options).
        num_timesteps: training stops once the rollout's ``tcount`` exceeds this value.
        envs_per_process: number of environment factories created for this process.
    """

    # Interval, in rollout timesteps (``tcount``), at which "model_last.ckpt" is refreshed.
    LAST_CKPT_INTERVAL = 10000

    def __init__(self, make_env, hps, num_timesteps, envs_per_process):
        self.make_env = make_env
        self.hps = hps
        self.envs_per_process = envs_per_process
        self.num_timesteps = num_timesteps
        self._set_env_vars()
        self.policy = CnnPolicy(scope='pol',
                                ob_space=self.ob_space,
                                ac_space=self.ac_space,
                                hidsize=512,
                                feat_dim=512,
                                ob_mean=self.ob_mean,
                                ob_std=self.ob_std,
                                layernormalize=False,
                                nl=tf.nn.leaky_relu)
        self.dynamic_bottleneck = DynamicBottleneck(
            policy=self.policy, feat_dim=512, tau=hps['momentum_tau'], loss_kl_weight=hps['loss_kl_weight'],
            loss_nce_weight=hps['loss_nce_weight'], loss_l2_weight=hps['loss_l2_weight'], aug=hps['aug'])
        self.agent = PpoOptimizer(
            scope='ppo',
            ob_space=self.ob_space,
            ac_space=self.ac_space,
            stochpol=self.policy,
            use_news=hps['use_news'],
            gamma=hps['gamma'],
            lam=hps["lambda"],
            nepochs=hps['nepochs'],
            nminibatches=hps['nminibatches'],
            lr=hps['lr'],
            cliprange=0.1,
            nsteps_per_seg=hps['nsteps_per_seg'],
            nsegs_per_env=hps['nsegs_per_env'],
            ent_coef=hps['ent_coeff'],
            normrew=hps['norm_rew'],
            normadv=hps['norm_adv'],
            ext_coeff=hps['ext_coeff'],
            int_coeff=hps['int_coeff'],
            dynamic_bottleneck=self.dynamic_bottleneck
        )
        # Build the mean Dynamic-Bottleneck loss once; the same tensor is reported,
        # added to the PPO total loss, and exposed as agent.db_loss.
        db_loss = tf.reduce_mean(self.dynamic_bottleneck.loss)
        self.agent.to_report['db'] = db_loss
        self.agent.total_loss += db_loss
        self.agent.db_loss = db_loss
        # Reported diagnostic: variance of the DB features over axes 0 and 1, averaged.
        self.agent.to_report['feat_var'] = tf.reduce_mean(tf.nn.moments(self.dynamic_bottleneck.features, [0, 1])[1])

    def _set_env_vars(self):
        """Probe one env for its spaces and observation statistics, then build per-env factories."""
        env = self.make_env(0, add_monitor=False)
        # e.g. Breakout: ob_space.shape == (84, 84, 4), ac_space == Discrete(4)
        self.ob_space, self.ac_space = env.observation_space, env.action_space
        self.ob_mean, self.ob_std = random_agent_ob_mean_std(env)
        # Release the probe env explicitly instead of relying on garbage collection.
        env.close()
        del env
        self.envs = [functools.partial(self.make_env, i) for i in range(self.envs_per_process)]

    @staticmethod
    def _checkpoint_index(tcount, interval):
        """Return how many whole ``interval``-sized windows fit in ``tcount`` (0 if interval <= 0)."""
        if interval <= 0:
            return 0
        return int(tcount // interval)

    def train(self, saver, logger_dir):
        """Run the PPO + Dynamic-Bottleneck loop until the rollout tcount exceeds ``num_timesteps``.

        Args:
            saver: TF saver used to write checkpoints.
            logger_dir: directory the checkpoints are written into.
        """
        self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamic_bottleneck=self.dynamic_bottleneck)
        previous_saved_tcount = 0  # index k of the last periodic "model_<k>.ckpt" written
        last_ckpt_index = 0        # last LAST_CKPT_INTERVAL window saved to "model_last.ckpt"
        # Initialize the Dynamic-Bottleneck momentum network (runs its init_updates ops once).
        print("***Init Momentum Network in Dynamic-Bottleneck.")
        getsess().run(self.dynamic_bottleneck.init_updates)
        while True:
            info = self.agent.step()
            if info['DB_loss_info']:  # per-step DB loss statistics, logged for debugging
                logger.logkvs(info['DB_loss_info'])
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()
            tcount = self.agent.rollout.stats['tcount']
            if self.hps["save_period"] and (self._checkpoint_index(tcount, self.hps["save_freq"]) > previous_saved_tcount):
                previous_saved_tcount += 1
                save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_"+str(previous_saved_tcount)+".ckpt"))
                print("Periodically model saved in path:", save_path)
            # tcount grows in chunks that need not divide LAST_CKPT_INTERVAL, so count
            # crossed windows instead of testing tcount % interval.
            finished = tcount > self.num_timesteps
            ckpt_index = self._checkpoint_index(tcount, self.LAST_CKPT_INTERVAL)
            if finished or ckpt_index > last_ckpt_index:
                last_ckpt_index = ckpt_index
                save_path = saver.save(tf.get_default_session(), os.path.join(logger_dir, "model_last.ckpt"))
                print("Model saved in path:", save_path)
            if finished:
                break
        self.agent.stop_interaction()
def make_env_all_params(rank, add_monitor, args):
    """Build one environment instance as configured by *args*.

    Args:
        rank: index of this env; names its Monitor file when ``add_monitor`` is set.
        add_monitor: if true, wrap the env in a baselines ``Monitor`` writing under
            ``logger.get_dir()``.
        args: option dict. Reads ``env_kind`` and ``env``; the Atari branch also reads
            ``stickyAtari``, ``max_episode_steps``, ``noop_max``, ``pixelNoise`` and
            ``randomBoxNoise``.

    Returns:
        The wrapped environment.

    Raises:
        ValueError: if ``env_kind`` is unknown, or ``env`` is not "pong"/"hockey"
            for the robopong kind.
    """
    env_kind = args["env_kind"]
    if env_kind == 'atari':
        env = gym.make(args['env'])
        assert 'NoFrameskip' in env.spec.id
        if args["stickyAtari"]:
            # Episode cap set on the env returned by gym.make; presumably x4 because
            # MaxAndSkipEnv(skip=4) below turns each agent step into 4 raw frames.
            env._max_episode_steps = args['max_episode_steps'] * 4
            env = StickyActionEnv(env)
        else:
            env = NoopResetEnv(env, noop_max=args['noop_max'])
        env = MaxAndSkipEnv(env, skip=4)
        # Optional observation-noise wrappers, applied before frame preprocessing.
        if args['pixelNoise']:
            env = PixelNoiseWrapper(env)
        if args['randomBoxNoise']:
            env = RandomBoxNoiseWrapper(env)
        env = ProcessFrame84(env, crop=False)
        env = FrameStack(env, 4)
        # The sticky variant relies on the _max_episode_steps cap set above instead.
        if not args["stickyAtari"]:
            env = ExtraTimeLimit(env, args['max_episode_steps'])
        if 'Montezuma' in args['env']:
            env = MontezumaInfoWrapper(env)
        env = AddRandomStateToInfo(env)
    elif env_kind == 'mario':
        env = make_mario_env()
    elif env_kind == "retro_multi":
        env = make_multi_pong()
    elif env_kind == 'robopong':
        if args["env"] == "pong":
            env = make_robo_pong()
        elif args["env"] == "hockey":
            env = make_robo_hockey()
        else:
            raise ValueError("unknown robopong env %r (expected 'pong' or 'hockey')" % (args["env"],))
    else:
        # Previously fell through and failed later with UnboundLocalError on `env`.
        raise ValueError("unknown env_kind %r" % (env_kind,))
    if add_monitor:
        env = Monitor(env, osp.join(logger.get_dir(), '%.2i' % rank))
    return env
def get_experiment_environment(**args):
    """Seed this MPI rank, select GPUs, create the log directory and build run contexts.

    Args:
        **args: option dict; reads ``seed``, ``env``, ``pixelNoise``,
            ``randomBoxNoise``, ``stickyAtari`` and ``comments`` and writes all
            entries to ``parameters.txt``.

    Returns:
        (logger_context, tf_context, saver, logger_dir)
    """
    from utils import setup_mpi_gpus, setup_tensorflow_session
    from baselines.common import set_global_seeds
    from gym.utils.seeding import hash_seed

    rank = MPI.COMM_WORLD.Get_rank()
    # Offset the seed per MPI rank, then hash it down to 4 bytes.
    set_global_seeds(hash_seed(args["seed"] + 1000 * rank, max_bytes=4))
    setup_mpi_gpus()

    # Log-dir name: env id minus "NoFrameskip-v4", one tag per enabled flag,
    # the optional free-form comment, and a month-day-time stamp.
    logger_dir = './logs/' + args["env"].replace("NoFrameskip-v4", "")
    flag_tags = (
        ('pixelNoise', "-pixelNoise"),
        ('randomBoxNoise', "-randomBoxNoise"),
        ('stickyAtari', "-stickyAtari"),
    )
    for flag, tag in flag_tags:
        if args[flag] is True:
            logger_dir += tag
    if args["comments"] != "":
        logger_dir += '-' + args["comments"]
    logger_dir += datetime.datetime.now().strftime("-%m-%d-%H-%M-%S")

    # Configure the logger on that directory and dump every option as "key: value".
    logger.configure(dir=logger_dir)
    with open(os.path.join(logger_dir, 'parameters.txt'), 'w') as f:
        f.write("\n".join(str(key) + ": " + str(value) for key, value in args.items()))

    # Rank 0 logs to stdout/log/csv; every other rank only to the log file.
    format_strs = ['stdout', 'log', 'csv'] if rank == 0 else ['log']
    logger_context = logger.scoped_configure(dir=logger_dir, format_strs=format_strs)
    tf_context = setup_tensorflow_session()
    saver = tf.train.Saver()
    return logger_context, tf_context, saver, logger_dir
def add_environments_params(parser):
    """Register environment-selection and observation-noise options on *parser*."""
    parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4', help='environment ID')
    parser.add_argument('--max-episode-steps', type=int, default=4500,
                        help='maximum number of timesteps for episode')
    parser.add_argument('--env_kind', type=str, default="atari")
    parser.add_argument('--noop_max', type=int, default=30)
    # Boolean switches, all off unless passed on the command line.
    for switch in ('--stickyAtari', '--pixelNoise', '--randomBoxNoise'):
        parser.add_argument(switch, action='store_true', default=False)
def add_optimization_params(parser):
    """Register PPO, checkpointing and Dynamic-Bottleneck loss options on *parser*."""
    option_table = [
        # lambda and gamma are used to compute the GAE advantage.
        ('--lambda', {'type': float, 'default': 0.95}),
        ('--gamma', {'type': float, 'default': 0.99}),
        ('--nminibatches', {'type': int, 'default': 8}),
        ('--norm_adv', {'type': int, 'default': 1}),
        ('--norm_rew', {'type': int, 'default': 1}),
        ('--lr', {'type': float, 'default': 1e-4}),
        ('--ent_coeff', {'type': float, 'default': 0.001}),
        ('--nepochs', {'type': int, 'default': 3}),
        ('--num_timesteps', {'type': int, 'default': int(1e8)}),
        # --save_period turns on periodic checkpoints every --save_freq timesteps.
        ('--save_period', {'action': 'store_true', 'default': False}),
        ('--save_freq', {'type': int, 'default': int(1e7)}),
        # Dynamic-Bottleneck: KL / l2 / NCE loss weights, momentum tau, data augmentation.
        ('--loss_kl_weight', {'type': float, 'default': 0.1}),
        ('--loss_l2_weight', {'type': float, 'default': 0.1}),
        ('--loss_nce_weight', {'type': float, 'default': 0.01}),
        ('--momentum_tau', {'type': float, 'default': 0.001}),
        ('--aug', {'action': 'store_true', 'default': False}),
        # Free-form suffix appended to the log-directory name.
        ('--comments', {'type': str, 'default': ""}),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
def add_rollout_params(parser):
    """Register the (integer) rollout-shape options on *parser*."""
    integer_defaults = (
        ('--nsteps_per_seg', 128),
        ('--nsegs_per_env', 1),
        ('--envs_per_process', 128),
        ('--nlumps', 1),
    )
    for flag, default in integer_defaults:
        parser.add_argument(flag, type=int, default=default)
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    for register_options in (add_environments_params, add_optimization_params, add_rollout_params):
        register_options(parser)
    parser.add_argument('--exp_name', type=str, default='')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    # Remaining scalar options, registered in their original order.
    for flag, arg_type, default in (('--dyn_from_pixels', int, 0),
                                    ('--use_news', int, 0),
                                    ('--ext_coeff', float, 0.),
                                    ('--int_coeff', float, 1.),
                                    ('--layernorm', int, 0)):
        parser.add_argument(flag, type=arg_type, default=default)
    args = parser.parse_args()

    # Load per-environment DB loss weights from para.json in the working directory.
    # Indexing below implies the layout {section: {env_name: {"kl": ..., "nce": ...}}}.
    with open("para.json") as f:
        para = json.load(f)
    # Environments without an entry under "standard" fall back to the "other" entry.
    # NOTE(review): membership is checked only against "standard"; an env listed there
    # but missing from the selected noise section raises KeyError below.
    env_name_para = args.env.replace("NoFrameskip-v4", "")
    if env_name_para not in para["standard"].keys():
        env_name_para = "other"

    # Select the para.json section from the noise flags; the first enabled flag wins.
    if args.pixelNoise is True:
        banner, section = "pixel noise", "pixelNoise"
    elif args.randomBoxNoise is True:
        banner, section = "random box noise", "randomBox"
    elif args.stickyAtari is True:
        banner, section = "sticky noise", "stickyAtari"
    else:
        banner, section = "standard atari", "standard"
    print(banner)
    # NOTE(review): this always overrides --loss_kl_weight / --loss_nce_weight from the CLI.
    weights = para[section][env_name_para]
    args.loss_kl_weight = weights["kl"]
    args.loss_nce_weight = weights["nce"]
    print("env_name:", env_name_para, "kl:", args.loss_kl_weight, ", nce:", args.loss_nce_weight)
    start_experiment(**args.__dict__)