Adding Plotting and Shell Script
This commit is contained in:
parent
bd4410e9d0
commit
c33f3e6dc8
339
Dreamer/plotting.py
Normal file
339
Dreamer/plotting.py
Normal file
@ -0,0 +1,339 @@
|
||||
import argparse
|
||||
import collections
|
||||
import functools
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import pathlib
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as ticker
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
# import matplotlib
|
||||
# matplotlib.rcParams['mathtext.fontset'] = 'stix'
|
||||
# matplotlib.rcParams['font.family'] = 'STIXGeneral'
|
||||
|
||||
Run = collections.namedtuple("Run", "task method seed xs ys color")
|
||||
|
||||
PALETTE = 10 * (
|
||||
"#377eb8",
|
||||
"#4daf4a",
|
||||
"#984ea3",
|
||||
"#e41a1c",
|
||||
"#ff7f00",
|
||||
"#a65628",
|
||||
"#f781bf",
|
||||
"#888888",
|
||||
"#a6cee3",
|
||||
"#b2df8a",
|
||||
"#cab2d6",
|
||||
"#fb9a99",
|
||||
"#fdbf6f",
|
||||
)
|
||||
|
||||
LEGEND = dict(
|
||||
fontsize="medium",
|
||||
numpoints=1,
|
||||
labelspacing=0,
|
||||
columnspacing=1.2,
|
||||
handlelength=1.5,
|
||||
handletextpad=0.5,
|
||||
ncol=4,
|
||||
loc="lower center",
|
||||
)
|
||||
|
||||
|
||||
def find_keys(args):
|
||||
filename = next(args.indir[0].glob("**/*.jsonl"))
|
||||
keys = set()
|
||||
for line in filename.read_text().split("\n"):
|
||||
if line:
|
||||
keys |= json.loads(line).keys()
|
||||
print(f"Keys ({len(keys)}):", ", ".join(keys), flush=True)
|
||||
|
||||
|
||||
def load_runs(args):
|
||||
toload = []
|
||||
for indir in args.indir:
|
||||
filenames = list(indir.glob("**/*.jsonl"))
|
||||
for filename in filenames:
|
||||
task, method, seed = filename.relative_to(indir).parts[:-1]
|
||||
if not any(p.search(task) for p in args.tasks):
|
||||
continue
|
||||
if not any(p.search(method) for p in args.methods):
|
||||
continue
|
||||
if method not in args.colors:
|
||||
args.colors[method] = args.palette[len(args.colors)]
|
||||
toload.append((filename, indir))
|
||||
print(f"Loading {len(toload)} of {len(filenames)} runs...")
|
||||
jobs = [functools.partial(load_run, f, i, args) for f, i in toload]
|
||||
with mp.Pool(10) as pool:
|
||||
promises = [pool.apply_async(j) for j in jobs]
|
||||
runs = [p.get() for p in promises]
|
||||
runs = [r for r in runs if r is not None]
|
||||
return runs
|
||||
|
||||
|
||||
def load_run(filename, indir, args):
|
||||
task, method, seed = filename.relative_to(indir).parts[:-1]
|
||||
num_steps = 1000000
|
||||
try:
|
||||
# Future pandas releases will support JSON files with NaN values.
|
||||
# df = pd.read_json(filename, lines=True)
|
||||
with filename.open() as f:
|
||||
data_arr = []
|
||||
for l in f.readlines():
|
||||
data = json.loads(l)
|
||||
data_arr.append(data)
|
||||
if data["step"] > num_steps:
|
||||
break
|
||||
df = pd.DataFrame(data_arr)
|
||||
except ValueError as e:
|
||||
print("Invalid", filename.relative_to(indir), e)
|
||||
return
|
||||
try:
|
||||
df = df[[args.xaxis, args.yaxis]].dropna()
|
||||
except KeyError:
|
||||
return
|
||||
xs = df[args.xaxis].to_numpy()
|
||||
ys = df[args.yaxis].to_numpy()
|
||||
color = args.colors[method]
|
||||
return Run(task, method, seed, xs, ys, color)
|
||||
|
||||
|
||||
def load_baselines(args):
|
||||
runs = []
|
||||
directory = pathlib.Path(__file__).parent / "baselines"
|
||||
for filename in directory.glob("**/*.json"):
|
||||
for task, methods in json.loads(filename.read_text()).items():
|
||||
for method, score in methods.items():
|
||||
if not any(p.search(method) for p in args.baselines):
|
||||
continue
|
||||
if method not in args.colors:
|
||||
args.colors[method] = args.palette[len(args.colors)]
|
||||
color = args.colors[method]
|
||||
runs.append(Run(task, method, None, None, score, color))
|
||||
return runs
|
||||
|
||||
|
||||
def stats(runs):
|
||||
baselines = sorted(set(r.method for r in runs if r.xs is None))
|
||||
runs = [r for r in runs if r.xs is not None]
|
||||
tasks = sorted(set(r.task for r in runs))
|
||||
methods = sorted(set(r.method for r in runs))
|
||||
seeds = sorted(set(r.seed for r in runs))
|
||||
print("Loaded", len(runs), "runs.")
|
||||
print(f"Tasks ({len(tasks)}):", ", ".join(tasks))
|
||||
print(f"Methods ({len(methods)}):", ", ".join(methods))
|
||||
print(f"Seeds ({len(seeds)}):", ", ".join(seeds))
|
||||
print(f"Baselines ({len(baselines)}):", ", ".join(baselines))
|
||||
|
||||
|
||||
def figure(runs, args):
|
||||
tasks = sorted(set(r.task for r in runs if r.xs is not None))
|
||||
rows = int(np.ceil(len(tasks) / args.cols))
|
||||
figsize = args.size[0] * args.cols, args.size[1] * rows
|
||||
fig, axes = plt.subplots(rows, args.cols, figsize=figsize)
|
||||
for task, ax in zip(tasks, axes.flatten()):
|
||||
relevant = [r for r in runs if r.task == task]
|
||||
plot(task, ax, relevant, args)
|
||||
if args.xlim:
|
||||
for ax in axes[:-1].flatten():
|
||||
ax.xaxis.get_offset_text().set_visible(False)
|
||||
if args.xlabel:
|
||||
for ax in axes[-1]:
|
||||
ax.set_xlabel(args.xlabel)
|
||||
if args.ylabel:
|
||||
for ax in axes[:, 0]:
|
||||
ax.set_ylabel(args.ylabel)
|
||||
for ax in axes[len(tasks) :]:
|
||||
ax.axis("off")
|
||||
legend(fig, args.labels, **LEGEND)
|
||||
return fig
|
||||
|
||||
|
||||
def plot(task, ax, runs, args):
|
||||
try:
|
||||
env, task = task.split("_", 1)
|
||||
title = env.capitalize() + " " + task.capitalize()
|
||||
# title = task.split('_', 1)[1].replace('_', ' ').title()
|
||||
except IndexError:
|
||||
title = task.title()
|
||||
ax.set_title(title)
|
||||
methods = []
|
||||
methods += sorted(set(r.method for r in runs if r.xs is not None))
|
||||
methods += sorted(set(r.method for r in runs if r.xs is None))
|
||||
xlim = [+np.inf, -np.inf]
|
||||
for index, method in enumerate(methods):
|
||||
relevant = [r for r in runs if r.method == method]
|
||||
if not relevant:
|
||||
continue
|
||||
if any(r.xs is None for r in relevant):
|
||||
baseline(index, method, ax, relevant, args)
|
||||
else:
|
||||
if args.aggregate == "std":
|
||||
xs, ys = curve_std(index, method, ax, relevant, args)
|
||||
elif args.aggregate == "none":
|
||||
xs, ys = curve_individual(index, method, ax, relevant, args)
|
||||
else:
|
||||
raise NotImplementedError(args.aggregate)
|
||||
xlim = [min(xlim[0], xs.min()), max(xlim[1], xs.max())]
|
||||
ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
|
||||
steps = [1, 2, 2.5, 5, 10]
|
||||
ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps))
|
||||
ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps))
|
||||
ax.set_xlim(args.xlim or xlim)
|
||||
if args.xlim:
|
||||
ticks = sorted({*ax.get_xticks(), *args.xlim})
|
||||
ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]]
|
||||
ax.set_xticks(ticks)
|
||||
if args.ylim:
|
||||
ax.set_ylim(args.ylim)
|
||||
ticks = sorted({*ax.get_yticks(), *args.ylim})
|
||||
ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]]
|
||||
ax.set_yticks(ticks)
|
||||
|
||||
|
||||
def curve_individual(index, method, ax, runs, args):
|
||||
if args.bins:
|
||||
for index, run in enumerate(runs):
|
||||
xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
|
||||
runs[index] = run._replace(xs=xs, ys=ys)
|
||||
zorder = 10000 - 10 * index - 1
|
||||
for run in runs:
|
||||
ax.plot(run.xs, run.ys, label=method, color=run.color, zorder=zorder)
|
||||
return runs[0].xs, runs[0].ys
|
||||
|
||||
|
||||
def curve_std(index, method, ax, runs, args):
|
||||
if args.bins:
|
||||
for index, run in enumerate(runs):
|
||||
xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
|
||||
runs[index] = run._replace(xs=xs, ys=ys)
|
||||
xs = np.concatenate([r.xs for r in runs])
|
||||
ys = np.concatenate([r.ys for r in runs])
|
||||
order = np.argsort(xs)
|
||||
xs, ys = xs[order], ys[order]
|
||||
color = runs[0].color
|
||||
if args.bins:
|
||||
def reducer(y):
|
||||
return (np.nanmean(np.array(y)), np.nanstd(np.array(y)))
|
||||
xs, ys = binning(xs, ys, args.bins, reducer)
|
||||
ys, std = ys.T
|
||||
kw = dict(color=color, zorder=10000 - 10 * index, alpha=0.1, linewidths=0)
|
||||
ax.fill_between(xs, ys - std, ys + std, **kw)
|
||||
ax.plot(xs, ys, label=method, color=color, zorder=10000 - 10 * index - 1)
|
||||
return xs, ys
|
||||
|
||||
|
||||
def baseline(index, method, ax, runs, args):
|
||||
assert len(runs) == 1 and runs[0].xs is None
|
||||
y = np.mean(runs[0].ys)
|
||||
kw = dict(ls="--", color=runs[0].color, zorder=5000 - 10 * index - 1)
|
||||
ax.axhline(y, label=method, **kw)
|
||||
|
||||
|
||||
def binning(xs, ys, bins, reducer):
|
||||
binned_xs = np.arange(xs.min(), xs.max() + 1e-10, bins)
|
||||
binned_ys = []
|
||||
for start, stop in zip([-np.inf] + list(binned_xs), binned_xs):
|
||||
left = (xs <= start).sum()
|
||||
right = (xs <= stop).sum()
|
||||
binned_ys.append(reducer(ys[left:right]))
|
||||
binned_ys = np.array(binned_ys)
|
||||
return binned_xs, binned_ys
|
||||
|
||||
|
||||
def legend(fig, mapping=None, **kwargs):
|
||||
entries = {}
|
||||
for ax in fig.axes:
|
||||
for handle, label in zip(*ax.get_legend_handles_labels()):
|
||||
if mapping and label in mapping:
|
||||
label = mapping[label]
|
||||
entries[label] = handle
|
||||
leg = fig.legend(entries.values(), entries.keys(), **kwargs)
|
||||
leg.get_frame().set_edgecolor("white")
|
||||
extent = leg.get_window_extent(fig.canvas.get_renderer())
|
||||
extent = extent.transformed(fig.transFigure.inverted())
|
||||
yloc, xloc = kwargs["loc"].split()
|
||||
y0 = dict(lower=extent.y1, center=0, upper=0)[yloc]
|
||||
y1 = dict(lower=1, center=1, upper=extent.y0)[yloc]
|
||||
x0 = dict(left=extent.x1, center=0, right=0)[xloc]
|
||||
x1 = dict(left=1, center=1, right=extent.x0)[xloc]
|
||||
fig.tight_layout(rect=[x0, y0, x1, y1], h_pad=0.5, w_pad=0.5)
|
||||
|
||||
|
||||
def save(fig, args):
|
||||
args.outdir.mkdir(parents=True, exist_ok=True)
|
||||
filename = args.outdir / "curves.png"
|
||||
fig.savefig(filename, dpi=130)
|
||||
print("Saved to", filename)
|
||||
filename = args.outdir / "curves.pdf"
|
||||
fig.savefig(filename)
|
||||
try:
|
||||
subprocess.call(["pdfcrop", str(filename), str(filename)])
|
||||
except FileNotFoundError:
|
||||
pass # Install texlive-extra-utils.
|
||||
|
||||
|
||||
def main(args):
|
||||
find_keys(args)
|
||||
runs = load_runs(args) + load_baselines(args)
|
||||
stats(runs)
|
||||
if not runs:
|
||||
print("Noting to plot.")
|
||||
return
|
||||
print("Plotting...")
|
||||
fig = figure(runs, args)
|
||||
save(fig, args)
|
||||
|
||||
|
||||
def parse_args():
|
||||
def boolean(x):
|
||||
return bool(["False", "True"].index(x))
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--indir", nargs="+", type=pathlib.Path, required=True)
|
||||
parser.add_argument("--outdir", type=pathlib.Path, required=True)
|
||||
parser.add_argument("--subdir", type=boolean, default=True)
|
||||
parser.add_argument("--xaxis", type=str, required=True)
|
||||
parser.add_argument("--yaxis", type=str, required=True)
|
||||
parser.add_argument("--tasks", nargs="+", default=[r".*"])
|
||||
parser.add_argument("--methods", nargs="+", default=[r".*"])
|
||||
parser.add_argument("--baselines", nargs="+", default=[])
|
||||
parser.add_argument("--bins", type=float, default=0)
|
||||
parser.add_argument("--aggregate", type=str, default="std")
|
||||
parser.add_argument("--size", nargs=2, type=float, default=[2.5, 2.3])
|
||||
parser.add_argument("--cols", type=int, default=4)
|
||||
parser.add_argument("--xlim", nargs=2, type=float, default=None)
|
||||
parser.add_argument("--ylim", nargs=2, type=float, default=None)
|
||||
parser.add_argument("--xlabel", type=str, default=None)
|
||||
parser.add_argument("--ylabel", type=str, default=None)
|
||||
parser.add_argument("--xticks", type=int, default=6)
|
||||
parser.add_argument("--yticks", type=int, default=5)
|
||||
parser.add_argument("--labels", nargs="+", default=None)
|
||||
parser.add_argument("--palette", nargs="+", default=PALETTE)
|
||||
parser.add_argument("--colors", nargs="+", default={})
|
||||
args = parser.parse_args()
|
||||
if args.subdir:
|
||||
args.outdir /= args.indir[0].stem
|
||||
args.indir = [d.expanduser() for d in args.indir]
|
||||
args.outdir = args.outdir.expanduser()
|
||||
if args.labels:
|
||||
assert len(args.labels) % 2 == 0
|
||||
args.labels = {k: v for k, v in zip(args.labels[:-1], args.labels[1:])}
|
||||
if args.colors:
|
||||
assert len(args.colors) % 2 == 0
|
||||
args.colors = {k: v for k, v in zip(args.colors[:-1], args.colors[1:])}
|
||||
args.tasks = [re.compile(p) for p in args.tasks]
|
||||
args.methods = [re.compile(p) for p in args.methods]
|
||||
args.baselines = [re.compile(p) for p in args.baselines]
|
||||
args.palette = 10 * args.palette
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(parse_args())
|
11
Dreamer/run_all_tia.sh
Executable file
11
Dreamer/run_all_tia.sh
Executable file
@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
|
||||
for seed in 0 1 2
|
||||
do
|
||||
python \
|
||||
run.py \
|
||||
--method tia \
|
||||
--configs dmc \
|
||||
--task dmc_cartpole_swingup_driving \
|
||||
--seed $seed
|
||||
done
|
Loading…
Reference in New Issue
Block a user