Adding logger

This commit is contained in:
Vedant Dave 2023-05-22 14:26:37 +02:00
parent f28db0ffe3
commit 4ac714c151

View File

@ -7,6 +7,7 @@ import torch
import torchvision import torchvision
import numpy as np import numpy as np
from termcolor import colored from termcolor import colored
from datetime import datetime
FORMAT_CONFIG = { FORMAT_CONFIG = {
'rl': { 'rl': {
@ -93,8 +94,10 @@ class MetersGroup(object):
class Logger(object): class Logger(object):
def __init__(self, log_dir, use_tb=True, config='rl'): def __init__(self, log_dir, use_tb=True, config='rl'):
self._log_dir = log_dir self._log_dir = log_dir
now = datetime.now()
dt_string = now.strftime("%d_%m_%Y-%H_%M_%S")
if use_tb: if use_tb:
tb_dir = os.path.join(log_dir, 'tb') tb_dir = os.path.join(log_dir, 'runs/tb_'+dt_string)
if os.path.exists(tb_dir): if os.path.exists(tb_dir):
shutil.rmtree(tb_dir) shutil.rmtree(tb_dir)
self._sw = SummaryWriter(tb_dir) self._sw = SummaryWriter(tb_dir)
@ -160,4 +163,4 @@ class Logger(object):
def dump(self, step): def dump(self, step):
self._train_mg.dump(step, 'train') self._train_mg.dump(step, 'train')
self._eval_mg.dump(step, 'eval') self._eval_mg.dump(step, 'eval')