DB/run.py
2023-05-29 17:11:26 +02:00

311 lines
13 KiB
Python

#!/usr/bin/env python
try:
from OpenGL import GLU
except:
print("no OpenGL.GLU")
import functools
import os.path as osp
from functools import partial
import os
import gym
import dmc2gym
import utils
import tensorflow as tf
from baselines import logger
from baselines.bench import Monitor
from baselines.common.atari_wrappers import NoopResetEnv, FrameStack
from mpi4py import MPI
from dynamic_bottleneck import DynamicBottleneck
from cnn_policy import CnnPolicy
from cppo_agent import PpoOptimizer
from utils import random_agent_ob_mean_std
from wrappers import MontezumaInfoWrapper, make_mario_env, make_robo_pong, make_robo_hockey, \
make_multi_pong, AddRandomStateToInfo, MaxAndSkipEnv, ProcessFrame84, ExtraTimeLimit, StickyActionEnv
import datetime
from wrappers import PixelNoiseWrapper, RandomBoxNoiseWrapper
import json
# Module-level shorthand for TensorFlow's default-session accessor
# (tf.get_default_session); call getsess() to obtain the active session.
getsess = tf.get_default_session
def start_experiment(**args):
    """Build the trainer, enter the logging / TF-session contexts and train.

    All keyword arguments are the parsed command-line hyper-parameters. They are
    forwarded unchanged to the environment factory, to ``Trainer`` (as ``hps``)
    and to ``get_experiment_environment``.
    """
    # The Trainer (and with it the TF graph) is created before
    # get_experiment_environment, presumably so that the tf.train.Saver created
    # there sees the model's variables.
    env_factory = partial(make_env_all_params, add_monitor=True, args=args)
    trainer = Trainer(
        make_env=env_factory,
        num_timesteps=args['num_timesteps'],
        hps=args,
        envs_per_process=args['envs_per_process'],
    )
    logger_ctx, session_ctx, saver, logger_dir = get_experiment_environment(**args)
    with logger_ctx, session_ctx:
        print("results will be saved to ", logger.get_dir())
        trainer.train(saver, logger_dir)
class Trainer(object):
    """Wire a CNN policy, the Dynamic-Bottleneck module and a PPO optimizer together and train.

    Args:
        make_env: Callable ``make_env(rank, add_monitor=...)`` returning an
            environment (in this file: ``partial(make_env_all_params, ...)``).
        hps: Hyper-parameter dict (the parsed command-line arguments).
        num_timesteps: Training budget. ``train`` stops once
            ``agent.rollout.stats['tcount']`` exceeds it.
        envs_per_process: Number of environment factories created for this process.
    """

    # Interval, in units of rollout.stats['tcount'], at which the rolling
    # "model_last.ckpt" checkpoint is refreshed during training.
    last_ckpt_interval = 10000

    def __init__(self, make_env, hps, num_timesteps, envs_per_process):
        self.make_env = make_env
        self.hps = hps
        self.envs_per_process = envs_per_process
        self.num_timesteps = num_timesteps
        self._set_env_vars()

        # Hidden/feature sizes and the leaky-ReLU nonlinearity are fixed here,
        # not read from hps.
        self.policy = CnnPolicy(scope='pol',
                                ob_space=self.ob_space,
                                ac_space=self.ac_space,
                                hidsize=512,
                                feat_dim=512,
                                ob_mean=self.ob_mean,
                                ob_std=self.ob_std,
                                layernormalize=False,
                                nl=tf.nn.leaky_relu)

        self.dynamic_bottleneck = DynamicBottleneck(
            policy=self.policy, feat_dim=512, tau=hps['momentum_tau'], loss_kl_weight=hps['loss_kl_weight'],
            loss_nce_weight=hps['loss_nce_weight'], loss_l2_weight=hps['loss_l2_weight'], aug=hps['aug'])

        self.agent = PpoOptimizer(
            scope='ppo',
            ob_space=self.ob_space,
            ac_space=self.ac_space,
            stochpol=self.policy,
            use_news=hps['use_news'],
            gamma=hps['gamma'],
            lam=hps["lambda"],
            nepochs=hps['nepochs'],
            nminibatches=hps['nminibatches'],
            lr=hps['lr'],
            cliprange=0.1,
            nsteps_per_seg=hps['nsteps_per_seg'],
            nsegs_per_env=hps['nsegs_per_env'],
            ent_coef=hps['ent_coeff'],
            normrew=hps['norm_rew'],
            normadv=hps['norm_adv'],
            ext_coeff=hps['ext_coeff'],
            int_coeff=hps['int_coeff'],
            dynamic_bottleneck=self.dynamic_bottleneck
        )

        # Build the mean Dynamic-Bottleneck loss once. The same tensor is
        # reported as 'db', added to the PPO total loss and exposed as
        # agent.db_loss.
        # NOTE(review): adding to total_loss after construction assumes
        # PpoOptimizer builds its train op later (e.g. in start_interaction) — confirm.
        db_loss = tf.reduce_mean(self.dynamic_bottleneck.loss)
        self.agent.to_report['db'] = db_loss
        self.agent.total_loss += db_loss
        self.agent.db_loss = db_loss
        # Variance of the bottleneck features over axes 0 and 1, averaged over
        # the remaining dimensions.
        self.agent.to_report['feat_var'] = tf.reduce_mean(tf.nn.moments(self.dynamic_bottleneck.features, [0, 1])[1])

    def _set_env_vars(self):
        """Probe one env for its spaces and observation statistics, then build per-rank env factories."""
        env = self.make_env(0, add_monitor=False)
        # NOTE(review): an earlier comment claimed ob_space.shape=(84, 84, 4) and
        # ac_space=Discrete(4) (Atari values) — confirm for the env actually built.
        self.ob_space, self.ac_space = env.observation_space, env.action_space
        self.ob_mean, self.ob_std = random_agent_ob_mean_std(env)
        del env
        self.envs = [functools.partial(self.make_env, i) for i in range(self.envs_per_process)]

    def train(self, saver, logger_dir):
        """Run optimisation steps until more than ``num_timesteps`` env steps have been collected.

        Args:
            saver: ``tf.train.Saver`` used to write checkpoints.
            logger_dir: Directory the checkpoints are written into.

        Checkpoints written:
            * ``model_<k>.ckpt`` every ``hps['save_freq']`` steps, only when
              ``hps['save_period']`` is set;
            * ``model_last.ckpt`` each time tcount enters a new
              ``last_ckpt_interval`` window, and once more when training ends.
        """
        self.agent.start_interaction(self.envs, nlump=self.hps['nlumps'], dynamic_bottleneck=self.dynamic_bottleneck)
        sess = getsess()
        previous_saved_tcount = 0  # number of numbered (model_<k>) checkpoints written so far
        last_ckpt_bucket = 0  # tcount // last_ckpt_interval at the last model_last save

        # Run the Dynamic-Bottleneck init_updates ops once before training.
        print("***Init Momentum Network in Dynamic-Bottleneck.")
        sess.run(self.dynamic_bottleneck.init_updates)

        while True:
            info = self.agent.step()
            if info['DB_loss_info']:
                logger.logkvs(info['DB_loss_info'])
            if info['update']:
                logger.logkvs(info['update'])
                logger.dumpkvs()

            tcount = self.agent.rollout.stats['tcount']
            if self.hps["save_period"] and tcount // self.hps["save_freq"] > previous_saved_tcount:
                previous_saved_tcount += 1
                save_path = saver.save(sess, os.path.join(logger_dir, "model_" + str(previous_saved_tcount) + ".ckpt"))
                print("Periodically model saved in path:", save_path)

            finished = tcount > self.num_timesteps
            # Previously `tcount % 10000` (truthy for NON-multiples) saved on
            # almost every step and the loop never ended. Now refresh on
            # window crossings and always on the final step.
            bucket = tcount // self.last_ckpt_interval
            if finished or bucket > last_ckpt_bucket:
                last_ckpt_bucket = bucket
                save_path = saver.save(sess, os.path.join(logger_dir, "model_last.ckpt"))
                print("Model saved in path:", save_path)
            if finished:
                break
        self.agent.stop_interaction()
def make_env_all_params(rank, add_monitor, args):
    """Build one pixel-based DeepMind Control environment for worker ``rank``.

    The task is fixed to cartpole/swingup with 84x84 pixel observations,
    ``frame_skip=4`` and a 4-frame stack. The ``env`` / ``env_kind`` arguments
    are not consulted here.

    Args:
        rank: Index of this environment within the process. It offsets the seed
            and names the Monitor log file.
        add_monitor: If true, wrap the env in a baselines ``Monitor`` writing
            into ``logger.get_dir()``.
        args: Hyper-parameter dict. Reads "seed", "img_source",
            "resource_files" and "total_frames".

    Returns:
        The frame-stacked (and optionally monitored) environment.
    """
    # Offset by rank so the parallel environments are not all seeded
    # identically (previously every rank received args["seed"]). Rank 0
    # keeps the original seed.
    # NOTE(review): the MPI rank is not included, so envs in different MPI
    # processes still share seeds.
    seed = args["seed"] + rank
    env = dmc2gym.make(
        domain_name='cartpole',
        task_name='swingup',
        seed=seed,
        visualize_reward=False,
        from_pixels=True,
        height=84,
        width=84,
        frame_skip=4,
        img_source=args["img_source"],
        resource_files=args["resource_files"],
        total_frames=args["total_frames"]
    )
    env.seed(seed)
    env = utils.FrameStack(env, k=4)
    if add_monitor:
        env = Monitor(env, osp.join(logger.get_dir(), '%.2i' % rank))
    return env
def get_experiment_environment(**args):
    """Seed RNGs, create the log directory and build the logging / TF-session contexts.

    Args:
        **args: Parsed command-line hyper-parameters. Reads "seed", "env",
            "pixelNoise", "randomBoxNoise", "stickyAtari" and "comments".
            Every entry is written to ``<logger_dir>/parameters.txt``.

    Returns:
        Tuple ``(logger_context, tf_context, saver, logger_dir)``.
    """
    from utils import setup_mpi_gpus, setup_tensorflow_session
    from baselines.common import set_global_seeds
    from gym.utils.seeding import hash_seed

    # Derive a distinct global seed per MPI rank.
    process_seed = args["seed"] + 1000 * MPI.COMM_WORLD.Get_rank()
    process_seed = hash_seed(process_seed, max_bytes=4)
    set_global_seeds(process_seed)
    setup_mpi_gpus()

    # Log directory: ./logs/<env>[-pixelNoise][-randomBoxNoise][-stickyAtari][-<comments>]-<timestamp>
    logger_dir = './logs/' + args["env"].replace("NoFrameskip-v4", "")
    for flag in ('pixelNoise', 'randomBoxNoise', 'stickyAtari'):
        if args[flag]:
            logger_dir += '-' + flag
    if args["comments"]:
        logger_dir += '-' + args["comments"]
    logger_dir += datetime.datetime.now().strftime("-%m-%d-%H-%M-%S")

    # Configure the logger (creates the directory) and record the full configuration.
    logger.configure(dir=logger_dir)
    with open(os.path.join(logger_dir, 'parameters.txt'), 'w', encoding='utf-8') as f:
        f.write("\n".join(str(key) + ": " + str(value) for key, value in args.items()))

    # Only rank 0 writes to stdout / CSV; other ranks keep a plain log.
    logger_context = logger.scoped_configure(
        dir=logger_dir,
        format_strs=['stdout', 'log', 'csv'] if MPI.COMM_WORLD.Get_rank() == 0 else ['log'])
    tf_context = setup_tensorflow_session()
    # tf.train.Saver() with no var_list covers the variables that exist when it
    # is created, so this must run after the model graph is built.
    saver = tf.train.Saver()
    return logger_context, tf_context, saver, logger_dir
def add_environments_params(parser):
    """Register environment-selection and observation-noise options on an argparse parser.

    Args:
        parser: ``argparse.ArgumentParser`` (or argument group) extended in place.
    """
    add = parser.add_argument
    # Environment identity and episode limits.
    add('--env', type=str, default='BreakoutNoFrameskip-v4', help='environment ID')
    add('--max-episode-steps', type=int, default=4500,
        help='maximum number of timesteps for episode')
    add('--env_kind', type=str, default="atari")
    add('--noop_max', type=int, default=30)
    # Boolean noise switches; all off unless passed on the command line.
    for switch in ('--stickyAtari', '--pixelNoise', '--randomBoxNoise'):
        add(switch, action='store_true', default=False)
    # Distractor settings forwarded to dmc2gym.make by make_env_all_params.
    add('--img_source', type=str, default=None,
        choices=['color', 'noise', 'images', 'video', 'none'])
    add('--resource_files', type=str)
    add('--total_frames', type=int, default=100)
def add_optimization_params(parser):
    """Register PPO optimisation, checkpointing and Dynamic-Bottleneck options on *parser*.

    Args:
        parser: ``argparse.ArgumentParser`` (or argument group) extended in place.
    """
    add = parser.add_argument
    # PPO settings. lambda and gamma are used to compute the GAE advantage.
    ppo_options = (
        ('--lambda', float, 0.95),
        ('--gamma', float, 0.99),
        ('--nminibatches', int, 8),
        ('--norm_adv', int, 1),
        ('--norm_rew', int, 1),
        ('--lr', float, 1e-4),
        ('--ent_coeff', float, 0.001),
        ('--nepochs', int, 3),
        ('--num_timesteps', int, int(1e8)),
    )
    for flag, kind, default in ppo_options:
        add(flag, type=kind, default=default)
    # Checkpointing: --save_period enables a numbered checkpoint every --save_freq steps.
    add('--save_period', action='store_true', default=False)
    add('--save_freq', type=int, default=int(1e7))
    # Dynamic-Bottleneck loss weights (KL, L2, NCE) and momentum tau.
    bottleneck_options = (
        ('--loss_kl_weight', 0.1),
        ('--loss_l2_weight', 0.1),
        ('--loss_nce_weight', 0.01),
        ('--momentum_tau', 0.001),
    )
    for flag, default in bottleneck_options:
        add(flag, type=float, default=default)
    add('--aug', action='store_true', default=False)  # data augmentation in the bottleneck
    add('--comments', type=str, default="")
def add_rollout_params(parser):
    """Register rollout-size options (segment length, segments, envs, lumps) on *parser*.

    Args:
        parser: ``argparse.ArgumentParser`` (or argument group) extended in place.
    """
    rollout_defaults = (
        ('--nsteps_per_seg', 128),
        ('--nsegs_per_env', 1),
        ('--envs_per_process', 128),
        ('--nlumps', 1),
    )
    for flag, default in rollout_defaults:
        parser.add_argument(flag, type=int, default=default)
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add_environments_params(parser)
    add_optimization_params(parser)
    add_rollout_params(parser)
    parser.add_argument('--exp_name', type=str, default='')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--dyn_from_pixels', type=int, default=0)
    parser.add_argument('--use_news', type=int, default=0)
    parser.add_argument('--ext_coeff', type=float, default=0.)
    parser.add_argument('--int_coeff', type=float, default=1.)
    parser.add_argument('--layernorm', type=int, default=0)
    args = parser.parse_args()

    # Load per-environment KL / NCE loss weights from para.json (read from the
    # current working directory). NOTE: these values always override
    # --loss_kl_weight / --loss_nce_weight given on the command line.
    with open("para.json", encoding="utf-8") as f:
        d = json.load(f)

    # Choose the para.json section for the active noise setting. The first
    # enabled flag wins, in this priority order.
    if args.pixelNoise:
        section, banner = "pixelNoise", "pixel noise"
    elif args.randomBoxNoise:
        section, banner = "randomBox", "random box noise"
    elif args.stickyAtari:
        section, banner = "stickyAtari", "sticky noise"
    else:
        section, banner = "standard", "standard atari"
    print(banner)

    env_name_para = args.env.replace("NoFrameskip-v4", "")
    # Look the env up in the section actually used (previously it was checked
    # against "standard" only, raising KeyError for envs missing from the
    # noise section). Fall back to the shared "other" entry.
    if env_name_para not in d[section]:
        env_name_para = "other"
    weights = d[section][env_name_para]
    args.loss_kl_weight = weights["kl"]
    args.loss_nce_weight = weights["nce"]
    print("env_name:", env_name_para, "kl:", args.loss_kl_weight, ", nce:", args.loss_nce_weight)
    start_experiment(**args.__dict__)