sac_ae_if/graphs_plot.py
2023-05-24 19:43:02 +02:00

49 lines
1.3 KiB
Python

import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
def tabulate_events(dpath, *, per_tag_steps=False):
    """Collect scalar summaries from a TensorBoard log directory.

    Only the first entry of *dpath* (in sorted order) is loaded, so the
    result currently holds a single run. The per-run bookkeeping is kept so
    more accumulators can be added to ``summary_iterators`` later.

    Args:
        dpath: Directory whose first entry (an event file or a run
            directory) is handed to ``EventAccumulator``.
        per_tag_steps: If False (default, legacy behaviour), return the step
            list of the *last* scalar tag only. If True, return a dict
            mapping every tag to its own step list, aligned index-for-index
            with ``out[tag]``.

    Returns:
        ``(out, steps)``. ``out[tag]`` is a list with one row per logged
        step; each row is a list of values, one per loaded run (i.e. the
        layout is ``out[tag][step_index][run_index]``).

    Raises:
        FileNotFoundError: if *dpath* does not exist or is empty.
        ValueError: if loaded runs disagree on their scalar tags or steps.
    """
    entries = sorted(os.listdir(dpath))
    if not entries:
        raise FileNotFoundError(f"no TensorBoard event data found in {dpath!r}")
    # Sorting makes the choice of entry deterministic; os.listdir order is
    # arbitrary.
    # NOTE(review): EventAccumulator's default size_guidance may subsample
    # long scalar series -- confirm whether full resolution is needed.
    summary_iterators = [EventAccumulator(os.path.join(dpath, entries[0])).Reload()]

    tags = summary_iterators[0].Tags()['scalars']
    for it in summary_iterators:
        if it.Tags()['scalars'] != tags:
            raise ValueError("runs do not log the same scalar tags")

    out = {t: [] for t in tags}
    tag_steps = {t: [] for t in tags}
    for tag in tags:
        # zip() walks all runs in lockstep: one tuple of events per step.
        for events in zip(*[acc.Scalars(tag) for acc in summary_iterators]):
            step = events[0].step
            if any(e.step != step for e in events):
                raise ValueError(f"runs disagree on the step for tag {tag!r}")
            tag_steps[tag].append(step)
            out[tag].append([e.value for e in events])

    if per_tag_steps:
        return out, tag_steps
    # Legacy contract: the original loop overwrote `steps` for every tag, so
    # only the last tag's steps were returned (an empty list if no tags).
    return out, (tag_steps[tags[-1]] if tags else [])


# Tags may be logged at different steps, so request per-tag step lists;
# events[tag] is laid out as [step_index][run_index].
events, steps = tabulate_events('/home/vedant/pytorch_sac_ae/log/runs', per_tag_steps=True)

# Flatten into long format: one row per (tag, run, step) value.
data = []
for tag, rows in events.items():
    for step, run_values in zip(steps[tag], rows):
        for run_idx, value in enumerate(run_values):
            data.append({
                'tag': tag,
                'run': run_idx,
                'step': step,
                'value': value,
            })
df = pd.DataFrame(data)
print(df.head())
plt.figure(figsize=(10,6))
# NOTE(review): `ci` is deprecated since seaborn 0.12 in favour of
# errorbar='sd'; kept here for compatibility with older seaborn versions.
sns.lineplot(data=df, x='step', y='value', hue='tag', ci='sd')
plt.show()