# File-viewer header: 339 lines, 12 KiB, Python.
|
import argparse
|
||
|
import collections
|
||
|
import functools
|
||
|
import json
|
||
|
import multiprocessing as mp
|
||
|
import pathlib
|
||
|
import re
|
||
|
import subprocess
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
import matplotlib.ticker as ticker
|
||
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
|
||
|
|
||
|
# import matplotlib
|
||
|
# matplotlib.rcParams['mathtext.fontset'] = 'stix'
|
||
|
# matplotlib.rcParams['font.family'] = 'STIXGeneral'
|
||
|
|
||
|
# One plotted entity. Learning curves carry arrays in ``xs``/``ys``; baseline
# scores are stored with ``seed=None`` and ``xs=None`` and the score in ``ys``.
Run = collections.namedtuple(
    "Run",
    ["task", "method", "seed", "xs", "ys", "color"],
)
|
||
|
|
||
|
# Default color cycle (hex strings). The thirteen base colors are repeated
# ten times so that the tuple can be indexed well past thirteen entries.
PALETTE = (
    "#377eb8",
    "#4daf4a",
    "#984ea3",
    "#e41a1c",
    "#ff7f00",
    "#a65628",
    "#f781bf",
    "#888888",
    "#a6cee3",
    "#b2df8a",
    "#cab2d6",
    "#fb9a99",
    "#fdbf6f",
) * 10
|
||
|
|
||
|
# Keyword arguments used for the shared figure legend. ``loc`` must be a
# two-word "<vertical> <horizontal>" string because ``legend`` splits it.
LEGEND = {
    "fontsize": "medium",
    "numpoints": 1,
    "labelspacing": 0,
    "columnspacing": 1.2,
    "handlelength": 1.5,
    "handletextpad": 0.5,
    "ncol": 4,
    "loc": "lower center",
}
|
||
|
|
||
|
|
||
|
def find_keys(args):
    """Print the metric keys that occur in the first log file found.

    Searches ``args.indir[0]`` recursively for ``*.jsonl`` files, reads the
    first match and prints the union of the keys of all its JSON records.
    When no log file exists an empty key list is printed instead of raising
    ``StopIteration``.

    Args:
        args: Namespace whose ``indir`` is a non-empty list of
            ``pathlib.Path`` directories.
    """
    filename = next(args.indir[0].glob("**/*.jsonl"), None)
    keys = set()
    if filename is not None:
        for line in filename.read_text().split("\n"):
            # Skip blank and whitespace-only lines (e.g. a trailing newline).
            if line.strip():
                keys.update(json.loads(line).keys())
    # Sorted so the output does not depend on set iteration order.
    print(f"Keys ({len(keys)}):", ", ".join(sorted(keys)), flush=True)
|
||
|
|
||
|
|
||
|
def load_runs(args):
    """Find and load every run that matches the task and method filters.

    Log files are expected at ``<indir>/<task>/<method>/<seed>/<file>.jsonl``
    (the three directory levels are unpacked from the relative path). A
    method seen for the first time gets the next unused palette entry stored
    in ``args.colors``; this happens before the load jobs are submitted
    because ``load_run`` looks the color up in ``args.colors``.

    Args:
        args: Namespace providing ``indir`` (list of paths), ``tasks`` and
            ``methods`` (lists of compiled regexes), ``colors`` (dict, updated
            in place) and ``palette`` (sequence of colors).

    Returns:
        List of the successfully loaded runs; files for which ``load_run``
        returned ``None`` are dropped.
    """
    toload = []
    total = 0
    for indir in args.indir:
        filenames = list(indir.glob("**/*.jsonl"))
        # Count across all input directories, not only the last one.
        total += len(filenames)
        for filename in filenames:
            task, method, _seed = filename.relative_to(indir).parts[:-1]
            if not any(p.search(task) for p in args.tasks):
                continue
            if not any(p.search(method) for p in args.methods):
                continue
            if method not in args.colors:
                args.colors[method] = args.palette[len(args.colors)]
            toload.append((filename, indir))
    print(f"Loading {len(toload)} of {total} runs...")
    if not toload:
        # Nothing matched: avoid starting worker processes for no work.
        return []
    with mp.Pool(10) as pool:
        promises = [
            pool.apply_async(load_run, (filename, indir, args))
            for filename, indir in toload
        ]
        runs = [p.get() for p in promises]
    return [r for r in runs if r is not None]
|
||
|
|
||
|
|
||
|
def load_run(filename, indir, args):
    """Load one JSON-lines log file into a ``Run``.

    Records are read until one reports a ``step`` above one million (that
    record is still included). Blank lines are skipped and records without a
    ``step`` entry never trigger the cut-off.

    Args:
        filename: Path of the ``.jsonl`` file, located at
            ``<indir>/<task>/<method>/<seed>/<file>``.
        indir: Input directory that ``filename`` lives under.
        args: Namespace providing ``xaxis`` and ``yaxis`` (column names) and
            ``colors`` (dict from method name to color).

    Returns:
        A ``Run`` holding the rows in which both columns are present, or
        ``None`` when the file cannot be parsed or lacks either column.
    """
    task, method, seed = filename.relative_to(indir).parts[:-1]
    num_steps = 1000000
    try:
        # Parsed line by line instead of ``pd.read_json(filename, lines=True)``:
        # per the original note, only future pandas releases will support JSON
        # files with NaN values.
        records = []
        with filename.open() as f:
            for line in f:
                if not line.strip():
                    continue
                record = json.loads(line)
                records.append(record)
                # ``.get`` so that records without "step" do not raise an
                # uncaught KeyError.
                if record.get("step", 0) > num_steps:
                    break
        df = pd.DataFrame(records)
    except ValueError as e:
        print("Invalid", filename.relative_to(indir), e)
        return None
    try:
        df = df[[args.xaxis, args.yaxis]].dropna()
    except KeyError:
        # At least one of the requested columns is not logged in this run.
        return None
    xs = df[args.xaxis].to_numpy()
    ys = df[args.yaxis].to_numpy()
    color = args.colors[method]
    return Run(task, method, seed, xs, ys, color)
|
||
|
|
||
|
|
||
|
def load_baselines(args):
    """Collect baseline scores from the ``baselines`` folder next to this file.

    Every ``*.json`` file below that folder is read as a mapping
    ``{task: {method: score}}``. Methods matching any pattern in
    ``args.baselines`` become ``Run`` entries with ``seed`` and ``xs`` set to
    ``None`` and the score stored in ``ys``. Methods without a color yet get
    the next unused palette entry recorded in ``args.colors``.

    Args:
        args: Namespace providing ``baselines`` (compiled regexes),
            ``colors`` (dict, updated in place) and ``palette``.

    Returns:
        List of baseline ``Run`` entries.
    """
    baseline_dir = pathlib.Path(__file__).parent / "baselines"
    collected = []
    for path in baseline_dir.glob("**/*.json"):
        scores_by_task = json.loads(path.read_text())
        for task, methods in scores_by_task.items():
            for method, score in methods.items():
                wanted = any(p.search(method) for p in args.baselines)
                if not wanted:
                    continue
                if method not in args.colors:
                    args.colors[method] = args.palette[len(args.colors)]
                entry = Run(task, method, None, None, score, args.colors[method])
                collected.append(entry)
    return collected
|
||
|
|
||
|
|
||
|
def stats(runs):
    """Print a short summary of the loaded runs.

    Entries whose ``xs`` is ``None`` are treated as baselines and are listed
    separately; all other counts refer to the remaining runs only.

    Args:
        runs: Iterable of ``Run``-like tuples with ``task``, ``method``,
            ``seed`` and ``xs`` attributes.
    """
    baseline_names = sorted({r.method for r in runs if r.xs is None})
    curves = [r for r in runs if r.xs is not None]
    summary = [
        ("Tasks", sorted({r.task for r in curves})),
        ("Methods", sorted({r.method for r in curves})),
        ("Seeds", sorted({r.seed for r in curves})),
        ("Baselines", baseline_names),
    ]
    print("Loaded", len(curves), "runs.")
    for heading, names in summary:
        print(f"{heading} ({len(names)}):", ", ".join(names))
|
||
|
|
||
|
|
||
|
def figure(runs, args):
    """Create the figure with one subplot per task.

    Args:
        runs: All loaded runs, curves and baselines alike. Only tasks that
            have at least one curve (``xs`` not ``None``) get a subplot.
        args: Namespace providing ``cols``, ``size`` (width, height of one
            subplot), ``xlim``, ``xlabel``, ``ylabel`` and ``labels``.

    Returns:
        The matplotlib figure.
    """
    tasks = sorted({r.task for r in runs if r.xs is not None})
    rows = int(np.ceil(len(tasks) / args.cols))
    figsize = args.size[0] * args.cols, args.size[1] * rows
    # squeeze=False keeps ``axes`` two-dimensional even for a single row or
    # column, which the row/column indexing below depends on.
    fig, axes = plt.subplots(rows, args.cols, figsize=figsize, squeeze=False)
    flat_axes = axes.flatten()
    for task, ax in zip(tasks, flat_axes):
        relevant = [r for r in runs if r.task == task]
        plot(task, ax, relevant, args)
    if args.xlim:
        # With shared, explicit x limits only the last row keeps the offset text.
        for ax in axes[:-1].flatten():
            ax.xaxis.get_offset_text().set_visible(False)
    if args.xlabel:
        for ax in axes[-1]:
            ax.set_xlabel(args.xlabel)
    if args.ylabel:
        for ax in axes[:, 0]:
            ax.set_ylabel(args.ylabel)
    # Hide the unused cells of the grid. This must index the flattened axes;
    # slicing the 2-D array would select whole rows instead of single axes.
    for ax in flat_axes[len(tasks):]:
        ax.axis("off")
    legend(fig, args.labels, **LEGEND)
    return fig
|
||
|
|
||
|
|
||
|
def plot(task, ax, runs, args):
    """Draw all curves and baselines of one task onto ``ax``.

    Args:
        task: Task name. ``"<env>_<name>"`` is titled ``"Env Name"``; a name
            without an underscore is title-cased as a whole.
        ax: Matplotlib axes to draw on.
        runs: Runs belonging to this task (curves and baselines).
        args: Namespace providing ``aggregate`` (``"std"`` or ``"none"``),
            ``xticks``, ``yticks``, ``xlim`` and ``ylim``; it is also passed
            on to the curve and baseline helpers.

    Raises:
        NotImplementedError: If ``args.aggregate`` is not a known mode.
    """
    # ``partition`` handles names without an underscore; unpacking
    # ``task.split("_", 1)`` raises ValueError for them, which the previous
    # ``except IndexError`` did not catch.
    env, sep, name = task.partition("_")
    if sep:
        title = env.capitalize() + " " + name.capitalize()
    else:
        title = task.title()
    ax.set_title(title)
    # Curves first, then baselines; the position in this list sets the z-order.
    methods = []
    methods += sorted({r.method for r in runs if r.xs is not None})
    methods += sorted({r.method for r in runs if r.xs is None})
    xlim = [+np.inf, -np.inf]
    for index, method in enumerate(methods):
        relevant = [r for r in runs if r.method == method]
        if not relevant:
            continue
        if any(r.xs is None for r in relevant):
            baseline(index, method, ax, relevant, args)
            continue
        if args.aggregate == "std":
            xs, _ys = curve_std(index, method, ax, relevant, args)
        elif args.aggregate == "none":
            xs, _ys = curve_individual(index, method, ax, relevant, args)
        else:
            raise NotImplementedError(args.aggregate)
        # Track the x range covered by the plotted curves.
        xlim = [min(xlim[0], xs.min()), max(xlim[1], xs.max())]
    ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    steps = [1, 2, 2.5, 5, 10]
    ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps))
    ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps))
    ax.set_xlim(args.xlim or xlim)
    if args.xlim:
        # Make sure the explicit limits themselves show up as ticks.
        ticks = sorted({*ax.get_xticks(), *args.xlim})
        ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]]
        ax.set_xticks(ticks)
    if args.ylim:
        ax.set_ylim(args.ylim)
        ticks = sorted({*ax.get_yticks(), *args.ylim})
        ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]]
        ax.set_yticks(ticks)
|
||
|
|
||
|
|
||
|
def curve_individual(index, method, ax, runs, args):
    """Plot every run of a method as its own line.

    Args:
        index: Position of the method in the plotting order; determines the
            z-order of the lines.
        method: Method name, used as the legend label.
        ax: Matplotlib axes to draw on.
        runs: Non-empty list of runs of this method. The list is not modified.
        args: Namespace providing ``bins``; a truthy value is used as the bin
            width for averaging each run before plotting.

    Returns:
        The ``(xs, ys)`` arrays of the first run as plotted.
    """
    if args.bins:
        # Build a new list with a separate loop variable: reusing ``index``
        # here would overwrite the method index that sets the z-order below.
        binned = []
        for run in runs:
            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
            binned.append(run._replace(xs=xs, ys=ys))
        runs = binned
    zorder = 10000 - 10 * index - 1
    for run in runs:
        ax.plot(run.xs, run.ys, label=method, color=run.color, zorder=zorder)
    return runs[0].xs, runs[0].ys
|
||
|
|
||
|
|
||
|
def curve_std(index, method, ax, runs, args):
    """Plot the mean curve of a method with a shaded standard-deviation band.

    The points of all runs are pooled and sorted by x. With a truthy
    ``args.bins`` each run is first averaged per bin, then the pooled points
    are reduced per bin to mean and standard deviation, and the band
    ``mean +/- std`` is drawn. Without binning the pooled points are plotted
    directly and no band is drawn.

    Args:
        index: Position of the method in the plotting order; determines the
            z-order of the line and band.
        method: Method name, used as the legend label.
        ax: Matplotlib axes to draw on.
        runs: Non-empty list of runs of this method. The list is not modified.
        args: Namespace providing ``bins`` (bin width, or a falsy value).

    Returns:
        The plotted ``(xs, ys)`` arrays.
    """
    if args.bins:
        # Build a new list with a separate loop variable: reusing ``index``
        # here would overwrite the method index that sets the z-order below.
        binned = []
        for run in runs:
            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
            binned.append(run._replace(xs=xs, ys=ys))
        runs = binned
    xs = np.concatenate([r.xs for r in runs])
    ys = np.concatenate([r.ys for r in runs])
    order = np.argsort(xs)
    xs, ys = xs[order], ys[order]
    color = runs[0].color
    if args.bins:

        def reducer(y):
            values = np.array(y)
            return (np.nanmean(values), np.nanstd(values))

        xs, ys = binning(xs, ys, args.bins, reducer)
        # Each row holds (mean, std); split into two arrays.
        ys, std = ys.T
        kw = dict(color=color, zorder=10000 - 10 * index, alpha=0.1, linewidths=0)
        ax.fill_between(xs, ys - std, ys + std, **kw)
    ax.plot(xs, ys, label=method, color=color, zorder=10000 - 10 * index - 1)
    return xs, ys
|
||
|
|
||
|
|
||
|
def baseline(index, method, ax, runs, args):
    """Draw a baseline score as a dashed horizontal line.

    Args:
        index: Position of the method in the plotting order; determines the
            z-order of the line.
        method: Baseline name, used as the legend label.
        ax: Matplotlib axes to draw on.
        runs: List holding exactly one baseline entry (``xs`` is ``None``).
        args: Unused; accepted for a signature parallel to the curve helpers.
    """
    assert len(runs) == 1 and runs[0].xs is None
    entry = runs[0]
    level = np.mean(entry.ys)
    ax.axhline(
        level,
        label=method,
        ls="--",
        color=entry.color,
        zorder=5000 - 10 * index - 1,
    )
|
||
|
|
||
|
|
||
|
def binning(xs, ys, bins, reducer):
    """Reduce ``ys`` over consecutive x intervals of width ``bins``.

    Bin edges run from ``xs.min()`` up to ``xs.max()`` in steps of ``bins``.
    For every edge, ``reducer`` receives the slice of ``ys`` between the
    number of points at or below the previous edge and the number at or
    below this edge. The slicing by counts presumably relies on ``xs`` being
    sorted in ascending order.

    Args:
        xs: Array of x values.
        ys: Array of y values aligned with ``xs``.
        bins: Bin width.
        reducer: Callable mapping a slice of ``ys`` to the bin's value.

    Returns:
        Tuple ``(edges, values)`` of the bin edges and the reduced values as
        arrays.
    """
    edges = np.arange(xs.min(), xs.max() + 1e-10, bins)
    reduced = []
    # The first bin is open to the left, so it starts at the count of
    # points at or below minus infinity.
    lower = (xs <= -np.inf).sum()
    for edge in edges:
        upper = (xs <= edge).sum()
        reduced.append(reducer(ys[lower:upper]))
        # The next bin starts where this one ended.
        lower = upper
    return edges, np.array(reduced)
|
||
|
|
||
|
|
||
|
def legend(fig, mapping=None, **kwargs):
    """Add one de-duplicated legend to the figure and lay out around it.

    Legend entries are gathered from all axes of ``fig``; entries with the
    same (possibly renamed) label collapse into one, keeping the handle seen
    last. The subplots are then tightened into the part of the figure that
    the legend leaves free.

    Args:
        fig: Matplotlib figure.
        mapping: Optional dict renaming labels for display.
        **kwargs: Passed to ``fig.legend``. Must contain ``loc`` as a
            two-word ``"<lower|center|upper> <left|center|right>"`` string.
    """
    handles_by_label = {}
    for ax in fig.axes:
        handles, labels = ax.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            shown = mapping[label] if mapping and label in mapping else label
            handles_by_label[shown] = handle
    leg = fig.legend(handles_by_label.values(), handles_by_label.keys(), **kwargs)
    leg.get_frame().set_edgecolor("white")
    # Legend bounding box converted to figure-relative coordinates.
    renderer = fig.canvas.get_renderer()
    extent = leg.get_window_extent(renderer)
    extent = extent.transformed(fig.transFigure.inverted())
    yloc, xloc = kwargs["loc"].split()
    # Rectangle left over for the subplots, depending on the legend's side.
    bottom = {"lower": extent.y1, "center": 0, "upper": 0}[yloc]
    top = {"lower": 1, "center": 1, "upper": extent.y0}[yloc]
    left = {"left": extent.x1, "center": 0, "right": 0}[xloc]
    right = {"left": 1, "center": 1, "right": extent.x0}[xloc]
    fig.tight_layout(rect=[left, bottom, right, top], h_pad=0.5, w_pad=0.5)
|
||
|
|
||
|
|
||
|
def save(fig, args):
    """Write the figure to ``curves.png`` and ``curves.pdf`` in the output dir.

    The output directory is created if needed. The PDF is cropped in place
    with the external ``pdfcrop`` tool when that tool is available.

    Args:
        fig: Figure object providing ``savefig``.
        args: Namespace providing ``outdir`` as a ``pathlib.Path``.
    """
    args.outdir.mkdir(parents=True, exist_ok=True)
    png_path = args.outdir / "curves.png"
    fig.savefig(png_path, dpi=130)
    print("Saved to", png_path)
    pdf_path = args.outdir / "curves.pdf"
    fig.savefig(pdf_path)
    crop_command = ["pdfcrop", str(pdf_path), str(pdf_path)]
    try:
        subprocess.call(crop_command)
    except FileNotFoundError:
        # Cropping is optional; pdfcrop comes with texlive-extra-utils.
        pass
|
||
|
|
||
|
|
||
|
def main(args):
    """Load the runs and baselines, print a summary and save the plots.

    Args:
        args: Parsed command-line options as returned by ``parse_args``.
    """
    find_keys(args)
    runs = load_runs(args) + load_baselines(args)
    stats(runs)
    if not runs:
        print("Nothing to plot.")
        return
    print("Plotting...")
    fig = figure(runs, args)
    save(fig, args)
|
||
|
|
||
|
|
||
|
def parse_args(argv=None):
    """Parse and post-process the command-line options.

    Post-processing: optionally appends the stem of the first input directory
    to the output directory, expands ``~`` in all paths, turns the flat
    ``--labels`` / ``--colors`` lists into dicts, compiles the task, method
    and baseline patterns, and repeats the palette ten times.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """

    def boolean(x):
        # Accepts exactly "True" or "False"; anything else raises ValueError,
        # which argparse reports as an invalid argument.
        return bool(["False", "True"].index(x))

    def pairs(flat):
        # [k1, v1, k2, v2, ...] -> {k1: v1, k2: v2, ...}
        return dict(zip(flat[::2], flat[1::2]))

    parser = argparse.ArgumentParser()
    parser.add_argument("--indir", nargs="+", type=pathlib.Path, required=True)
    parser.add_argument("--outdir", type=pathlib.Path, required=True)
    parser.add_argument("--subdir", type=boolean, default=True)
    parser.add_argument("--xaxis", type=str, required=True)
    parser.add_argument("--yaxis", type=str, required=True)
    parser.add_argument("--tasks", nargs="+", default=[r".*"])
    parser.add_argument("--methods", nargs="+", default=[r".*"])
    parser.add_argument("--baselines", nargs="+", default=[])
    parser.add_argument("--bins", type=float, default=0)
    parser.add_argument("--aggregate", type=str, default="std")
    parser.add_argument("--size", nargs=2, type=float, default=[2.5, 2.3])
    parser.add_argument("--cols", type=int, default=4)
    parser.add_argument("--xlim", nargs=2, type=float, default=None)
    parser.add_argument("--ylim", nargs=2, type=float, default=None)
    parser.add_argument("--xlabel", type=str, default=None)
    parser.add_argument("--ylabel", type=str, default=None)
    parser.add_argument("--xticks", type=int, default=6)
    parser.add_argument("--yticks", type=int, default=5)
    parser.add_argument("--labels", nargs="+", default=None)
    parser.add_argument("--palette", nargs="+", default=PALETTE)
    # None instead of a shared ``{}`` default: the dict is filled in place
    # while loading, so every call must get its own instance.
    parser.add_argument("--colors", nargs="+", default=None)
    args = parser.parse_args(argv)
    if args.subdir:
        args.outdir /= args.indir[0].stem
    args.indir = [d.expanduser() for d in args.indir]
    args.outdir = args.outdir.expanduser()
    if args.labels:
        if len(args.labels) % 2 != 0:
            parser.error("--labels expects pairs: METHOD LABEL [METHOD LABEL ...]")
        # Pair up non-overlapping (key, value) items; zipping the list with
        # itself shifted by one also maps every value to the following key.
        args.labels = pairs(args.labels)
    if args.colors:
        if len(args.colors) % 2 != 0:
            parser.error("--colors expects pairs: METHOD COLOR [METHOD COLOR ...]")
        args.colors = pairs(args.colors)
    else:
        args.colors = {}
    args.tasks = [re.compile(p) for p in args.tasks]
    args.methods = [re.compile(p) for p in args.methods]
    args.baselines = [re.compile(p) for p in args.baselines]
    args.palette = 10 * args.palette
    return args
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry point: parse the command line, then load, plot and save.
    cli_args = parse_args()
    main(cli_args)
|