DBC/graph_utils.py
2020-10-12 15:39:25 -07:00

314 lines
8.2 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import matplotlib.pyplot as plt
import os
import json
import re
import scipy.interpolate
def read_log_file(file_name, key_name, value_name, smooth=3):
    """Read one ``(key, value)`` series from a JSON-lines log file.

    Each line of ``file_name`` is parsed as JSON. Lines are skipped when they
    are not valid JSON, lack ``key_name``/``value_name``, or hold fields that
    cannot be converted (key to ``int``, value to ``float``).

    Args:
        file_name: Path of the log file.
        key_name: JSON field used as the x coordinate (converted to ``int``).
        value_name: JSON field used as the y coordinate (converted to
            ``float``).
        smooth: Width of a moving-average window applied to the values.
            Values <= 1 disable smoothing. The window is clipped to the number
            of samples, so the output length always matches ``keys``.

    Returns:
        ``(keys, values)`` as numpy arrays of equal length, or
        ``(None, None)`` if the file cannot be opened or decoded.
    """
    keys, values = [], []
    try:
        with open(file_name, 'r') as f:
            for line in f:
                try:
                    entry = json.loads(line.strip())
                    # Convert both fields before appending anything, so a bad
                    # value can never leave keys and values misaligned.
                    key = int(entry[key_name])
                    value = float(entry[value_name])
                except (ValueError, KeyError, IndexError, TypeError,
                        OverflowError):
                    continue
                keys.append(key)
                values.append(value)
    except (OSError, UnicodeDecodeError):
        print('bad file: %s' % file_name)
        return None, None
    keys, values = np.array(keys), np.array(values)
    num_samples = values.shape[0]
    if smooth > 1 and num_samples > 0:
        # np.convolve(..., 'same') returns max(len(a), len(v)) samples, so a
        # window wider than the series would change the output length.
        window = np.ones(min(smooth, num_samples))
        # Per-position count of samples inside the window; dividing by it
        # averages only over real samples near the edges.
        counts = np.convolve(np.ones(num_samples), window, 'same')
        values = np.convolve(values, window, 'same') / counts
    return keys, values
def parse_log_files(
    file_name_template,
    key_name,
    value_name,
    num_seeds,
    smooth,
    best_k=None,
    max_key=True
):
    """Aggregate one logged metric across seeds onto a shared key grid.

    Seeds ``1..num_seeds`` are loaded with ``read_log_file`` from
    ``file_name_template % seed``. Missing or empty logs are skipped. The
    reference grid is the key array of the run whose last key is largest
    (``max_key=True``) or smallest (``max_key=False``). Every run is linearly
    interpolated onto that grid, with extrapolation outside its own range.

    Args:
        file_name_template: Path containing one ``%d`` placeholder for the
            seed number.
        key_name: JSON field used as the x coordinate.
        value_name: JSON field used as the y coordinate.
        num_seeds: Number of seeds to try, numbered from 1.
        smooth: Moving-average window passed to ``read_log_file``.
        best_k: If given, only the ``best_k`` largest values at each key are
            aggregated.
        max_key: Use the longest (True) or shortest (False) run's keys as the
            reference grid.

    Returns:
        ``(keys, means, half_stds)``, where ``half_stds`` is half the
        population standard deviation across seeds at each key. Returns
        ``(None, None, None)`` if no seed produced any data.
    """
    all_keys, all_values = [], []
    for seed in range(1, num_seeds + 1):
        run_keys, run_values = read_log_file(
            file_name_template % seed, key_name, value_name, smooth
        )
        if run_keys is None or run_keys.shape[0] == 0:
            continue
        all_keys.append(run_keys)
        all_values.append(run_values)
    if not all_values:
        return None, None, None

    # sorted() is stable, so runs tied on their last key keep seed order.
    runs_by_last_key = sorted(all_keys, key=lambda k: k[-1])
    keys = runs_by_last_key[-1] if max_key else runs_by_last_key[0]

    resampled = []
    for run_keys, run_values in zip(all_keys, all_values):
        if run_keys.shape[0] < 2:
            # interp1d needs at least two samples; treat a single-sample run
            # as constant over the whole grid.
            resampled.append(
                np.full(keys.shape[0], run_values[0], dtype=float)
            )
            continue
        interp = scipy.interpolate.interp1d(
            run_keys, run_values, fill_value='extrapolate'
        )
        resampled.append(interp(keys))

    # One row per seed, one column per key of the reference grid.
    stacked = np.stack(resampled)
    if best_k is not None:
        # Keep the best_k largest values independently at every key.
        stacked = np.sort(stacked, axis=0)[-best_k:]
    means = stacked.mean(axis=0)
    half_stds = 0.5 * stacked.std(axis=0)
    print(file_name_template, means[-1])
    return keys, means, half_stds
def print_result(
    root,
    title,
    label=None,
    num_seeds=1,
    smooth=3,
    train=False,
    key_name='step',
    value_name='episode_reward',
    max_time=None,
    best_k=None,
    timescale=1,
    max_key=False
):
    """Plot the seed-averaged curve for one experiment directory.

    Loads ``root/seed_<n>/train.log`` (or ``eval.log``) for seeds
    ``1..num_seeds`` via ``parse_log_files``. Draws the mean as a line with a
    shaded band of +/- ``half_stds``, then applies the shared axis, grid and
    legend styling. Draws nothing if no seed produced data.

    Args:
        root: Experiment directory containing ``seed_%d`` subdirectories.
        title: Plot title.
        label: Legend label; defaults to the last ``/``-separated component
            of ``root``.
        num_seeds: Number of seeds to try, numbered from 1.
        smooth: Moving-average window applied to each seed's values.
        train: Read ``train.log`` when True, ``eval.log`` otherwise.
        key_name: JSON field used as the x axis.
        value_name: JSON field used as the y axis.
        max_time: If given, keep only points whose unscaled key is
            <= max_time.
        best_k: Forwarded to ``parse_log_files``.
        timescale: Multiplier applied to the keys before plotting.
        max_key: Forwarded to ``parse_log_files``.
    """
    file_name = 'train.log' if train else 'eval.log'
    file_name_template = os.path.join(root, 'seed_%d', file_name)
    keys, means, half_stds = parse_log_files(
        file_name_template,
        key_name,
        value_name,
        num_seeds,
        smooth=smooth,
        best_k=best_k,
        max_key=max_key
    )
    label = label or root.split('/')[-1]
    if keys is None:
        return
    if max_time is not None:
        # The cut-off is applied before timescale, i.e. in raw key units.
        keep = keys <= max_time
        keys, means, half_stds = keys[keep], means[keep], half_stds[keep]
    # Out-of-place multiply: keys come from int() conversions, so an in-place
    # `*=` with a float timescale would raise a numpy casting error.
    keys = keys * timescale
    plt.plot(keys, means, label=label)
    plt.locator_params(nbins=10, axis='x')
    plt.locator_params(nbins=10, axis='y')
    plt.rcParams['figure.figsize'] = (10, 7)
    plt.rcParams['figure.dpi'] = 100
    plt.rcParams['font.size'] = 10
    plt.subplots_adjust(left=0.165, right=0.99, bottom=0.16, top=0.95)
    plt.tight_layout()
    plt.grid(alpha=0.8)
    plt.title(title)
    plt.fill_between(keys, means - half_stds, means + half_stds, alpha=0.2)
    plt.legend(loc='lower right', prop={
        'size': 6
    }).get_frame().set_edgecolor('0.1')
    plt.xlabel(key_name)
    plt.ylabel(value_name)
def plot_seeds(
    task,
    exp_query,
    root,
    train=True,
    smooth=3,
    key_name='step',
    value_name='episode_reward',
    num_seeds=10
):
    """Plot every seed of one experiment as a separate thin curve.

    The experiment is the first entry of ``root``, in sorted name order, that
    matches the regex ``exp_query`` via ``re.match``. ``root`` is searched
    directly; ``task`` is used only as the plot title. Returns without
    drawing anything if no entry matches.

    Args:
        task: Plot title.
        exp_query: Regex matched against the start of each entry name.
        root: Directory holding experiment subdirectories.
        train: Read ``train.log`` when True, ``eval.log`` otherwise.
        smooth: Moving-average window applied to each seed's values.
        key_name: JSON field used as the x axis.
        value_name: JSON field used as the y axis.
        num_seeds: Number of seeds to try, numbered from 1; seeds with no
            data are skipped.
    """
    # Sorted so the chosen experiment does not depend on filesystem order.
    experiment = next(
        (
            os.path.join(root, exp)
            for exp in sorted(os.listdir(root))
            if re.match(exp_query, exp)
        ),
        None
    )
    if experiment is None:
        return
    file_name = 'train.log' if train else 'eval.log'
    file_name_template = os.path.join(experiment, 'seed_%d', file_name)
    plt.locator_params(nbins=10, axis='x')
    plt.locator_params(nbins=10, axis='y')
    plt.rcParams['figure.figsize'] = (10, 7)
    plt.rcParams['figure.dpi'] = 100
    plt.rcParams['font.size'] = 10
    plt.subplots_adjust(left=0.165, right=0.99, bottom=0.16, top=0.95)
    plt.grid(alpha=0.8)
    plt.tight_layout()
    plt.title(task)
    plt.xlabel(key_name)
    plt.ylabel(value_name)
    for seed in range(1, num_seeds + 1):
        keys, values = read_log_file(
            file_name_template % seed, key_name, value_name, smooth=smooth
        )
        if keys is None or keys.shape[0] == 0:
            continue
        plt.plot(keys, values, label='seed_%d' % seed, linewidth=0.5)
    plt.legend(loc='lower right', prop={
        'size': 6
    }).get_frame().set_edgecolor('0.1')
def print_baseline(task, baseline, data, color):
    """Draw a dashed horizontal line at a baseline's score for ``task``.

    The y value is ``data[task][baseline]``. If that lookup fails (missing
    key or index, or a non-subscriptable entry), nothing is drawn and the
    function returns None.

    Args:
        task: Outer key into ``data``.
        baseline: Inner key into ``data[task]``; also used as the legend label.
        data: Nested mapping of task -> baseline -> value.
        color: Matplotlib color for the line.
    """
    try:
        value = data[task][baseline]
    except (LookupError, TypeError):
        return
    plt.axhline(y=value, label=baseline, linestyle='--', color=color)
    plt.legend(loc='lower right', prop={
        'size': 6
    }).get_frame().set_edgecolor('0.1')
def print_planet_baseline(
    task, data, max_time=None, label='planet', color='black', offset=0
):
    """Plot a precomputed baseline curve (mean with +/- band) for ``task``.

    ``data[task]`` must unpack into ``(keys, means, half_stds)``. If it is
    missing or malformed, nothing is drawn and the function returns None.

    Args:
        task: Key into ``data``.
        data: Mapping of task -> ``(keys, means, half_stds)``; sequences or
            numpy arrays are accepted.
        max_time: If given, keep only points with key <= max_time. The cut
            uses ``np.searchsorted``, which requires keys sorted ascending.
        label: Legend label.
        color: Matplotlib color for the line and band.
        offset: Added to every key before plotting.
    """
    try:
        keys, means, half_stds = data[task]
    except (LookupError, TypeError, ValueError):
        return
    # Convert so arithmetic below works for plain lists as well as arrays.
    keys = np.asarray(keys)
    means = np.asarray(means)
    half_stds = np.asarray(half_stds)
    if max_time is not None:
        # side='right' keeps points whose key equals max_time, the same
        # inclusive cut-off (`keys <= max_time`) used by print_result.
        idx = np.searchsorted(keys, max_time, side='right')
        keys, means, half_stds = keys[:idx], means[:idx], half_stds[:idx]
    shifted_keys = keys + offset
    plt.plot(shifted_keys, means, label=label, color=color)
    plt.fill_between(
        shifted_keys,
        means - half_stds,
        means + half_stds,
        alpha=0.2,
        color=color
    )
    plt.legend(loc='lower right', prop={
        'size': 6
    }).get_frame().set_edgecolor('0.1')
def plot_experiment(
    task,
    exp_query,
    neg_exp_query=None,
    root='runs',
    exp_ids=None,
    smooth=3,
    train=False,
    key_name='step',
    value_name='eval_episode_reward',
    baselines_data=None,
    num_seeds=10,
    planet_data=None,
    slac_data=None,
    max_time=None,
    best_k=None,
    timescale=1,
    max_key=False
):
    """Plot every matching experiment for ``task``, plus optional baselines.

    Experiments are the entries of ``root/task`` whose names match
    ``exp_query`` and, when ``neg_exp_query`` is given, do not match it
    (both via ``re.match``). They are visited in sorted order. ``exp_ids``
    selects which positions in that order are plotted; None means all.

    Args:
        task: Task name; used as subdirectory of ``root``, plot title and key
            into the baseline data.
        exp_query: Regex an experiment name must match.
        neg_exp_query: Optional regex an experiment name must not match.
        root: Directory containing one subdirectory per task.
        exp_ids: Container of positions (in sorted order) to plot.
        smooth, train, key_name, value_name, num_seeds, max_time, best_k,
        timescale, max_key: Forwarded to ``print_result``.
        baselines_data: Optional nested mapping for ``print_baseline``
            ('d4pg_pixels' and 'd4pg' lines).
        planet_data: Optional mapping for the PlaNet curve.
        slac_data: Optional mapping for the SLAC curve. It is only drawn for
            tasks listed in the action-repeat table below.
    """
    task_root = os.path.join(root, task)
    experiments = sorted(
        os.path.join(task_root, exp)
        for exp in os.listdir(task_root)
        if re.match(exp_query, exp)
        and (neg_exp_query is None or re.match(neg_exp_query, exp) is None)
    )
    for exp_id, exp in enumerate(experiments):
        if exp_ids is not None and exp_id not in exp_ids:
            continue
        print_result(
            exp,
            task,
            smooth=smooth,
            num_seeds=num_seeds,
            train=train,
            key_name=key_name,
            value_name=value_name,
            max_time=max_time,
            best_k=best_k,
            timescale=timescale,
            max_key=max_key
        )
    if baselines_data is not None:
        print_baseline(task, 'd4pg_pixels', baselines_data, color='gray')
        print_baseline(task, 'd4pg', baselines_data, color='black')
    if planet_data is not None:
        print_planet_baseline(
            task,
            planet_data,
            max_time=max_time,
            label='planet',
            color='peru',
            offset=5
        )
    # Per-task action repeat; the SLAC curve is shifted right by 10x it.
    action_repeat = {
        'ball_in_cup_catch': 4,
        'cartpole_swingup': 8,
        'cheetah_run': 4,
        'finger_spin': 2,
        'walker_walk': 2,
        'reacher_easy': 4
    }
    # Tasks without a known action repeat are skipped instead of raising
    # KeyError, just as tasks missing from the baseline data are skipped.
    if slac_data is not None and task in action_repeat:
        print_planet_baseline(
            task,
            slac_data,
            max_time=max_time,
            label='slac',
            color='black',
            offset=10 * action_repeat[task]
        )