194 lines
8.9 KiB
Python
194 lines
8.9 KiB
Python
from collections import deque, defaultdict
|
|
|
|
import numpy as np
|
|
from mpi4py import MPI
|
|
# from utils import add_noise
|
|
from recorder import Recorder
|
|
|
|
|
|
class Rollout(object):
    """Collect fixed-length rollouts from vectorized environments.

    Each call to :meth:`collect_rollout` steps every environment ``nsteps``
    times, filling per-step buffers indexed ``[env, step]`` (observations,
    actions, value predictions, negative log-probabilities, episode-start
    flags and external rewards). It then computes training rewards by mixing
    the clipped external reward with the intrinsic reward returned by
    ``dynamic_bottleneck.calculate_db_reward``. Finished-episode statistics
    are aggregated across MPI ranks in :meth:`update_info`.

    Environments are split into ``nlumps`` groups ("lumps"). Lump ``l`` owns
    buffer rows ``[l * lump_stride, (l + 1) * lump_stride)`` and is driven
    through ``envs[l]`` via ``reset`` / ``step_async`` / ``step_wait``.
    """

    def __init__(self, ob_space, ac_space, nenvs, nsteps_per_seg, nsegs_per_env, nlumps, envs, policy,
                 int_rew_coeff, ext_rew_coeff, record_rollouts, dynamic_bottleneck,
                 debug_frame_env=None):
        """Allocate rollout buffers and bookkeeping state.

        Args:
            ob_space: observation space; its ``shape`` and ``dtype`` size the
                observation buffer (e.g. Box(84, 84, 4)).
            ac_space: action space; its ``shape`` and ``dtype`` size the action
                buffer (e.g. Discrete(4)).
            nenvs: total number of parallel environments (rows of every buffer).
            nsteps_per_seg: number of steps per segment.
            nsegs_per_env: segments per rollout; ``nsteps = nsteps_per_seg * nsegs_per_env``.
            nlumps: number of env groups; each gets ``nenvs // nlumps`` rows.
            envs: sequence of ``nlumps`` vectorized envs exposing ``reset``,
                ``step_async`` and ``step_wait``.
            policy: object whose ``get_ac_value_nlp(obs)`` returns
                ``(actions, value_preds, neglogps)``.
            int_rew_coeff: weight of the intrinsic (dynamic-bottleneck) reward.
            ext_rew_coeff: weight of the external reward, which is clipped to [-1, 1] first.
            record_rollouts: if true, create a ``Recorder`` and log every step to it.
            dynamic_bottleneck: object whose ``calculate_db_reward(ob=, last_ob=, acs=)``
                returns the intrinsic reward for the whole rollout.
            debug_frame_env: optional index (within a lump) of an environment whose
                current observation channels are written to ``image1.png``,
                ``image2.png``, ... on every step, for visual debugging. ``None``
                (the default) disables this; enabling it requires Pillow.
        """
        self.nenvs = nenvs
        self.nsteps_per_seg = nsteps_per_seg
        self.nsegs_per_env = nsegs_per_env
        self.nsteps = self.nsteps_per_seg * self.nsegs_per_env
        self.ob_space = ob_space
        self.ac_space = ac_space
        self.nlumps = nlumps
        self.lump_stride = nenvs // self.nlumps
        self.envs = envs
        self.policy = policy
        self.dynamic_bottleneck = dynamic_bottleneck
        self.debug_frame_env = debug_frame_env

        # Training reward: external reward clipped to [-1, 1], mixed linearly with the intrinsic reward.
        self.reward_fun = lambda ext_rew, int_rew: ext_rew_coeff * np.clip(ext_rew, -1., 1.) + int_rew_coeff * int_rew

        # Per-step buffers, indexed [env, step].
        self.buf_vpreds = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_nlps = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_rews = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_ext_rews = np.empty((nenvs, self.nsteps), np.float32)
        self.buf_acs = np.empty((nenvs, self.nsteps, *self.ac_space.shape), self.ac_space.dtype)
        self.buf_obs = np.empty((nenvs, self.nsteps, *self.ob_space.shape), self.ob_space.dtype)
        # Observation that follows the last step of each segment, indexed [env, segment].
        self.buf_obs_last = np.empty((nenvs, self.nsegs_per_env, *self.ob_space.shape), np.float32)

        self.buf_news = np.zeros((nenvs, self.nsteps), np.float32)
        # Episode-start flag and value prediction for the observation after the final step.
        self.buf_new_last = self.buf_news[:, 0, ...].copy()
        self.buf_vpred_last = self.buf_vpreds[:, 0, ...].copy()

        # Cached (obs, rews, news, infos) per lump; None means a step_async result is still pending.
        self.env_results = [None] * self.nlumps
        # Not updated anywhere in this class; only forwarded to the recorder.
        self.int_rew = np.zeros((nenvs,), np.float32)

        self.recorder = Recorder(nenvs=self.nenvs, nlumps=self.nlumps) if record_rollouts else None
        self.statlists = defaultdict(lambda: deque([], maxlen=100))
        self.stats = defaultdict(float)
        self.best_ext_ret = None
        self.all_visited_rooms = []
        self.all_scores = []
        # (step_count, episode_info) pairs for episodes finished since the last update_info.
        # Initialised here so rollout_step also works before the first collect_rollout.
        self.ep_infos_new = []

        self.step_count = 0

    def collect_rollout(self):
        """Run ``nsteps`` env steps, then compute rewards and update episode statistics."""
        self.ep_infos_new = []
        for _ in range(self.nsteps):
            self.rollout_step()
        self.calculate_reward()
        self.update_info()

    def calculate_reward(self):
        """Fill ``buf_rews`` from the dynamic-bottleneck reward and the clipped external reward."""
        db_rew = self.dynamic_bottleneck.calculate_db_reward(
            ob=self.buf_obs, last_ob=self.buf_obs_last, acs=self.buf_acs)
        self.buf_rews[:] = self.reward_fun(int_rew=db_rew, ext_rew=self.buf_ext_rews)

    def rollout_step(self):
        """Advance every lump by one environment step and record the results.

        For each lump: fetch the result of the previous action (or the reset
        observation on the very first call), record finished-episode infos,
        query the policy, launch the next ``step_async``, and store everything
        at buffer column ``t = step_count % nsteps``. The reward returned with
        an observation belongs to the previous action, so it is written to
        column ``t - 1``. At the end of a segment the next observation is
        stored in ``buf_obs_last``. At the end of the rollout, the final
        reward, episode-start flag and bootstrap value prediction are stored
        as well.
        """
        t = self.step_count % self.nsteps
        s = t % self.nsteps_per_seg
        for l in range(self.nlumps):
            obs, prevrews, news, infos = self.env_get(l)
            for info in infos:
                # Merge the episode summaries the various env wrappers may emit.
                # NOTE: this updates info['episode'] in place when that key exists.
                epinfo = info.get('episode', {})
                epinfo.update(info.get('mz_episode', {}))
                epinfo.update(info.get('retro_episode', {}))
                if epinfo:
                    if "n_states_visited" in info:
                        epinfo["n_states_visited"] = info["n_states_visited"]
                        epinfo["states_visited"] = info["states_visited"]
                    self.ep_infos_new.append((self.step_count, epinfo))

            sli = slice(l * self.lump_stride, (l + 1) * self.lump_stride)

            acs, vpreds, nlps = self.policy.get_ac_value_nlp(obs)
            self.env_step(l, acs)

            self.buf_obs[sli, t] = obs
            self.buf_news[sli, t] = news
            self.buf_vpreds[sli, t] = vpreds
            self.buf_nlps[sli, t] = nlps  # -log pi(a|s)
            self.buf_acs[sli, t] = acs
            if t > 0:
                self.buf_ext_rews[sli, t - 1] = prevrews
            if self.recorder is not None:
                self.recorder.record(timestep=self.step_count, lump=l, acs=acs, infos=infos, int_rew=self.int_rew[sli],
                                     ext_rew=prevrews, news=news)
            if self.debug_frame_env is not None:
                self._dump_debug_frames(self.buf_obs[sli, t])

        self.step_count += 1
        if s == self.nsteps_per_seg - 1:
            # End of a segment: fetch the pending result to get the next observation.
            # env_get caches it, so the next rollout_step reuses it instead of waiting again.
            for l in range(self.nlumps):
                sli = slice(l * self.lump_stride, (l + 1) * self.lump_stride)
                nextobs, ext_rews, nextnews, _ = self.env_get(l)
                self.buf_obs_last[sli, t // self.nsteps_per_seg] = nextobs
                if t == self.nsteps - 1:
                    # Last step of the rollout: store its reward and the bootstrap quantities.
                    self.buf_new_last[sli] = nextnews
                    self.buf_ext_rews[sli, t] = ext_rews
                    _, self.buf_vpred_last[sli], _ = self.policy.get_ac_value_nlp(nextobs)

    def _dump_debug_frames(self, lump_obs):
        """Save each channel of env ``debug_frame_env``'s observation as image<k>.png (k from 1).

        Assumes observations are (H, W, C) arrays in a dtype Pillow accepts
        (e.g. uint8) — TODO confirm for the env in use. Files are overwritten
        on every call.
        """
        from PIL import Image  # optional, debug-only dependency
        frame = lump_obs[self.debug_frame_env]
        for c in range(frame.shape[-1]):
            Image.fromarray(frame[..., c]).save('image{}.png'.format(c + 1))

    def update_info(self):
        """Aggregate finished-episode infos from all MPI ranks into running statistics.

        Gathers every rank's ``ep_infos_new`` with ``allgather``, orders the
        episodes by the step at which they were recorded, and updates
        ``statlists`` / ``stats`` (episode reward, length, counts).
        Montezuma-style infos (``visited_rooms``) and Retro-style infos
        (``levels``) get extra bookkeeping, which rank 0 prints. Finally it
        resets ``ep_infos_new`` and sets ``current_max`` to the largest
        episode reward in this batch (``None`` if no episode finished). It
        also raises ``best_ext_ret`` when that value is exceeded.

        Every gathered episode dict must contain the keys of the first one
        (at least ``'r'`` and ``'l'``); otherwise a ``KeyError`` is raised.
        """
        gathered = MPI.COMM_WORLD.allgather(self.ep_infos_new)
        # Flatten the per-rank lists (linear, unlike sum(lists, [])); sorted() is stable.
        episodes = sorted((item for rank_infos in gathered for item in rank_infos), key=lambda x: x[0])
        if episodes:
            episodes = [epinfo for _, epinfo in episodes]  # drop the step_count
            keys_ = episodes[0].keys()
            # e.g. {'r': [0.0, 0.0], 'l': [124, 125], 't': [6.6, 12.0]}
            ep_infos = {k: [epinfo[k] for epinfo in episodes] for k in keys_}

            self.statlists['eprew'].extend(ep_infos['r'])
            self.stats['eprew_recent'] = np.mean(ep_infos['r'])
            self.statlists['eplen'].extend(ep_infos['l'])
            self.stats['epcount'] += len(ep_infos['l'])
            self.stats['tcount'] += sum(ep_infos['l'])
            if 'visited_rooms' in keys_:
                # Montezuma specific logging.
                self.stats['visited_rooms'] = sorted(set.union(*ep_infos['visited_rooms']))
                self.stats['pos_count'] = np.mean(ep_infos['pos_count'])
                self.all_visited_rooms.extend(self.stats['visited_rooms'])
                self.all_scores.extend(ep_infos["r"])
                self.all_scores = sorted(set(self.all_scores))
                self.all_visited_rooms = sorted(set(self.all_visited_rooms))
                if MPI.COMM_WORLD.Get_rank() == 0:
                    print("All visited rooms")
                    print(self.all_visited_rooms)
                    print("All scores")
                    print(self.all_scores)
            if 'levels' in keys_:
                # Retro logging
                levels = sorted(set.union(*ep_infos['levels']))
                self.all_visited_rooms.extend(levels)
                self.all_visited_rooms = sorted(set(self.all_visited_rooms))
                if MPI.COMM_WORLD.Get_rank() == 0:
                    print("All visited levels")
                    print(self.all_visited_rooms)

            current_max = np.max(ep_infos['r'])
        else:
            current_max = None
        self.ep_infos_new = []

        if current_max is not None:
            if (self.best_ext_ret is None) or (current_max > self.best_ext_ret):
                self.best_ext_ret = current_max
        self.current_max = current_max

    def env_step(self, l, acs):
        """Send ``acs`` to lump ``l`` asynchronously and invalidate its cached result."""
        self.envs[l].step_async(acs)
        self.env_results[l] = None

    def env_get(self, l):
        """Return ``(obs, rews, news, infos)`` for lump ``l``.

        On the very first step (``step_count == 0``) the lump's envs are reset
        and a placeholder result is returned: rewards ``None``, every
        episode-start flag ``True``, and empty infos. Otherwise the pending
        ``step_wait`` result is fetched once and cached. Repeated calls before
        the next :meth:`env_step` therefore return the same tuple.
        """
        if self.step_count == 0:
            ob = self.envs[l].reset()
            out = self.env_results[l] = (ob, None, np.ones(self.lump_stride, bool), {})
        else:
            if self.env_results[l] is None:
                out = self.env_results[l] = self.envs[l].step_wait()
            else:
                out = self.env_results[l]
        return out
|