from torch.utils.tensorboard import SummaryWriter from collections import defaultdict import json import os import shutil import torch import torchvision import numpy as np from termcolor import colored FORMAT_CONFIG = { 'rl': { 'train': [ ('episode', 'E', 'int'), ('step', 'S', 'int'), ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'), ('batch_reward', 'BR', 'float'), ('actor_loss', 'ALOSS', 'float'), ('critic_loss', 'CLOSS', 'float'), ('ae_loss', 'RLOSS', 'float') ], 'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float')] } } class AverageMeter(object): def __init__(self): self._sum = 0 self._count = 0 def update(self, value, n=1): self._sum += value self._count += n def value(self): return self._sum / max(1, self._count) class MetersGroup(object): def __init__(self, file_name, formating): self._file_name = file_name if os.path.exists(file_name): os.remove(file_name) self._formating = formating self._meters = defaultdict(AverageMeter) def log(self, key, value, n=1): self._meters[key].update(value, n) def _prime_meters(self): data = dict() for key, meter in self._meters.items(): if key.startswith('train'): key = key[len('train') + 1:] else: key = key[len('eval') + 1:] key = key.replace('/', '_') data[key] = meter.value() return data def _dump_to_file(self, data): with open(self._file_name, 'a') as f: f.write(json.dumps(data) + '\n') def _format(self, key, value, ty): template = '%s: ' if ty == 'int': template += '%d' elif ty == 'float': template += '%.04f' elif ty == 'time': template += '%.01f s' else: raise 'invalid format type: %s' % ty return template % (key, value) def _dump_to_console(self, data, prefix): prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') pieces = ['{:5}'.format(prefix)] for key, disp_key, ty in self._formating: value = data.get(key, 0) pieces.append(self._format(disp_key, value, ty)) print('| %s' % (' | '.join(pieces))) def dump(self, step, prefix): if len(self._meters) == 0: return data = self._prime_meters() data['step'] = step self._dump_to_file(data) self._dump_to_console(data, prefix) self._meters.clear() class Logger(object): def __init__(self, log_dir, use_tb=True, config='rl'): self._log_dir = log_dir if use_tb: tb_dir = os.path.join(log_dir, 'tb') if os.path.exists(tb_dir): shutil.rmtree(tb_dir) self._sw = SummaryWriter(tb_dir) else: self._sw = None self._train_mg = MetersGroup( os.path.join(log_dir, 'train.log'), formating=FORMAT_CONFIG[config]['train'] ) self._eval_mg = MetersGroup( os.path.join(log_dir, 'eval.log'), formating=FORMAT_CONFIG[config]['eval'] ) def _try_sw_log(self, key, value, step): if self._sw is not None: self._sw.add_scalar(key, value, step) def _try_sw_log_image(self, key, image, step): if self._sw is not None: assert image.dim() == 3 grid = torchvision.utils.make_grid(image.unsqueeze(1)) self._sw.add_image(key, grid, step) def _try_sw_log_video(self, key, frames, step): if self._sw is not None: frames = torch.from_numpy(np.array(frames)) frames = frames.unsqueeze(0) self._sw.add_video(key, frames, step, fps=30) def _try_sw_log_histogram(self, key, histogram, step): if self._sw is not None: self._sw.add_histogram(key, histogram, step) def log(self, key, value, step, n=1): assert key.startswith('train') or key.startswith('eval') if type(value) == torch.Tensor: value = value.item() self._try_sw_log(key, value / n, step) mg = self._train_mg if key.startswith('train') else self._eval_mg mg.log(key, value, n) def log_param(self, key, param, step): self.log_histogram(key + '_w', param.weight.data, step) if hasattr(param.weight, 'grad') and param.weight.grad is not None: self.log_histogram(key + '_w_g', param.weight.grad.data, step) if hasattr(param, 'bias'): self.log_histogram(key + '_b', param.bias.data, step) if hasattr(param.bias, 'grad') and param.bias.grad is not None: self.log_histogram(key + '_b_g', param.bias.grad.data, step) def log_image(self, key, image, step): assert key.startswith('train') or key.startswith('eval') self._try_sw_log_image(key, image, step) def log_video(self, key, frames, step): assert key.startswith('train') or key.startswith('eval') self._try_sw_log_video(key, frames, step) def log_histogram(self, key, histogram, step): assert key.startswith('train') or key.startswith('eval') self._try_sw_log_histogram(key, histogram, step) def dump(self, step): self._train_mg.dump(step, 'train') self._eval_mg.dump(step, 'eval')