166 lines
5.5 KiB
Python
166 lines
5.5 KiB
Python
from torch.utils.tensorboard import SummaryWriter
|
|
from collections import defaultdict
|
|
import json
|
|
import os
|
|
import shutil
|
|
import torch
|
|
import torchvision
|
|
import numpy as np
|
|
from termcolor import colored
|
|
|
|
|
|
# Console layout per logging configuration.
#
# FORMAT_CONFIG[config][split] is an ordered list of
# (metric key, label printed on the console, value type) triples, where the
# value type is one of 'int', 'float' or 'time'.
FORMAT_CONFIG = {
    'rl': {
        'train': [
            ('episode', 'E', 'int'),
            ('step', 'S', 'int'),
            ('duration', 'D', 'time'),
            ('episode_reward', 'R', 'float'),
            ('batch_reward', 'BR', 'float'),
            ('actor_loss', 'ALOSS', 'float'),
            ('critic_loss', 'CLOSS', 'float'),
            ('ae_loss', 'RLOSS', 'float'),
        ],
        'eval': [
            ('step', 'S', 'int'),
            ('episode_reward', 'ER', 'float'),
        ],
    },
}
|
|
|
|
|
|
class AverageMeter(object):
    """Running total plus sample count, reported as their ratio."""

    def __init__(self):
        self._sum, self._count = 0, 0

    def update(self, value, n=1):
        """Add ``value`` to the total and ``n`` to the sample count.

        ``value`` is accumulated as given; it is not multiplied by ``n``.
        """
        self._sum += value
        self._count += n

    def value(self):
        """Return total / count, with the divisor clamped to at least 1."""
        divisor = self._count if self._count > 1 else 1
        return self._sum / divisor
|
|
|
|
|
|
class MetersGroup(object):
    """Group of named ``AverageMeter`` objects flushed together.

    Every ``dump`` appends one JSON line to ``file_name``, prints one
    formatted row to the console, and then resets all meters.
    """

    # printf-style value templates, keyed by the type tag found in the
    # third element of each ``formating`` entry.
    _VALUE_TEMPLATES = {
        'int': '%d',
        'float': '%.04f',
        'time': '%.01f s',
    }

    def __init__(self, file_name, formating):
        """Create the group.

        Args:
            file_name: path of the JSON-lines log file. An existing file at
                this path is deleted so the log starts empty.
            formating: ordered iterable of ``(key, display_key, type)``
                triples describing the console row.
        """
        self._file_name = file_name
        if os.path.exists(file_name):
            os.remove(file_name)
        self._formating = formating
        self._meters = defaultdict(AverageMeter)

    def log(self, key, value, n=1):
        """Accumulate ``value`` (with sample count ``n``) under ``key``."""
        self._meters[key].update(value, n)

    def _prime_meters(self):
        """Return ``{short_key: averaged value}`` for all current meters.

        The leading ``train``/``eval`` prefix and the separator character
        after it are removed, and remaining ``/`` become ``_``.
        """
        data = dict()
        for key, meter in self._meters.items():
            # Any key that does not start with 'train' is treated as if it
            # started with 'eval'; the +1 drops the separator character.
            if key.startswith('train'):
                key = key[len('train') + 1:]
            else:
                key = key[len('eval') + 1:]
            key = key.replace('/', '_')
            data[key] = meter.value()
        return data

    def _dump_to_file(self, data):
        """Append ``data`` to the log file as a single JSON line."""
        with open(self._file_name, 'a') as f:
            f.write(json.dumps(data) + '\n')

    def _format(self, key, value, ty):
        """Render ``'<key>: <value>'`` according to the type tag ``ty``.

        Args:
            key: label to print.
            value: number to print.
            ty: one of ``'int'``, ``'float'`` or ``'time'``.

        Raises:
            ValueError: if ``ty`` is not a known type tag.
        """
        value_template = self._VALUE_TEMPLATES.get(ty)
        if value_template is None:
            # Raising a plain string is itself a TypeError in Python 3 and
            # hides the message, so raise a real exception instead.
            raise ValueError('invalid format type: %s' % ty)
        return ('%s: ' + value_template) % (key, value)

    def _dump_to_console(self, data, prefix):
        """Print one ``|``-separated row following ``self._formating``.

        Metrics missing from ``data`` are printed as 0.
        """
        prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
        pieces = ['{:5}'.format(prefix)]
        for key, disp_key, ty in self._formating:
            value = data.get(key, 0)
            pieces.append(self._format(disp_key, value, ty))
        print('| %s' % (' | '.join(pieces)))

    def dump(self, step, prefix):
        """Flush the meters to file and console, then reset them.

        Does nothing when no metric has been logged since the last dump.

        Args:
            step: value stored under the ``'step'`` key of the record.
            prefix: row label printed on the console (``'train'`` is
                coloured yellow, anything else green).
        """
        if not self._meters:
            return
        data = self._prime_meters()
        data['step'] = step
        self._dump_to_file(data)
        self._dump_to_console(data, prefix)
        self._meters.clear()
|
|
|
|
|
|
class Logger(object):
    """Front end that sends metrics to TensorBoard and to train/eval meters.

    Scalar metrics go both to the optional ``SummaryWriter`` and to a
    ``MetersGroup`` (``train.log`` or ``eval.log`` inside ``log_dir``).
    Images, videos and histograms go to the ``SummaryWriter`` only and are
    silently skipped when TensorBoard is disabled.
    """

    def __init__(self, log_dir, use_tb=True, config='rl'):
        """Set up the writers.

        Args:
            log_dir: directory that receives ``train.log``, ``eval.log`` and
                the ``tb`` sub-directory.
            use_tb: when true, create a ``SummaryWriter`` in ``log_dir/tb``;
                an existing ``tb`` directory is removed first.
            config: key into ``FORMAT_CONFIG`` selecting the console layout.
        """
        self._log_dir = log_dir
        if use_tb:
            tb_dir = os.path.join(log_dir, 'tb')
            # Start from an empty TensorBoard directory.
            if os.path.exists(tb_dir):
                shutil.rmtree(tb_dir)
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        self._train_mg = MetersGroup(
            os.path.join(log_dir, 'train.log'),
            formating=FORMAT_CONFIG[config]['train'])
        self._eval_mg = MetersGroup(
            os.path.join(log_dir, 'eval.log'),
            formating=FORMAT_CONFIG[config]['eval'])

    def _try_sw_log(self, key, value, step):
        """Write a scalar to TensorBoard if it is enabled."""
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_image(self, key, image, step):
        """Write a 3-D image tensor to TensorBoard if it is enabled."""
        if self._sw is not None:
            assert image.dim() == 3
            # unsqueeze(1) inserts a size-1 axis after the first one, so
            # make_grid receives a 4-D batch with one entry per slice of
            # the first dimension.
            grid = torchvision.utils.make_grid(image.unsqueeze(1))
            self._sw.add_image(key, grid, step)

    def _try_sw_log_video(self, key, frames, step):
        """Write a frame sequence to TensorBoard (30 fps) if it is enabled."""
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            # Prepend a size-1 leading axis (presumably the batch axis
            # add_video expects -- confirm against the frame layout used
            # by callers).
            frames = frames.unsqueeze(0)
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        """Write a histogram to TensorBoard if it is enabled."""
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1):
        """Record a scalar metric.

        Args:
            key: metric name; must start with ``'train'`` or ``'eval'``.
            value: number or single-element tensor. It is divided by ``n``
                for TensorBoard and accumulated undivided, together with
                ``n``, in the meter group.
            step: global step for TensorBoard.
            n: sample count that ``value`` covers.
        """
        assert key.startswith('train') or key.startswith('eval')
        # isinstance (not an exact type match) so Tensor subclasses such as
        # nn.Parameter are converted to a Python scalar as well.
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step):
        """Log histograms of a layer's weight, bias and their gradients.

        Args:
            key: base name; ``_w``, ``_w_g``, ``_b`` and ``_b_g`` are
                appended for weight, weight gradient, bias and bias gradient.
            param: object with a ``weight`` attribute and optionally a
                ``bias`` attribute (``bias`` may be ``None``).
            step: global step for TensorBoard.
        """
        self.log_histogram(key + '_w', param.weight.data, step)
        weight_grad = getattr(param.weight, 'grad', None)
        if weight_grad is not None:
            self.log_histogram(key + '_w_g', weight_grad.data, step)
        # A layer created without a bias still has the attribute, set to
        # None, so test the value rather than mere attribute presence.
        bias = getattr(param, 'bias', None)
        if bias is not None:
            self.log_histogram(key + '_b', bias.data, step)
            bias_grad = getattr(bias, 'grad', None)
            if bias_grad is not None:
                self.log_histogram(key + '_b_g', bias_grad.data, step)

    def log_image(self, key, image, step):
        """Log an image under a ``train``/``eval`` key (TensorBoard only)."""
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_image(key, image, step)

    def log_video(self, key, frames, step):
        """Log a video under a ``train``/``eval`` key (TensorBoard only)."""
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step):
        """Log a histogram under a ``train``/``eval`` key (TensorBoard only)."""
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step):
        """Flush the train and eval meter groups for ``step``."""
        self._train_mg.dump(step, 'train')
        self._eval_mg.dump(step, 'eval')
|