diff --git a/Dreamer/plotting.py b/Dreamer/plotting.py
new file mode 100644
index 0000000..49cb85f
--- /dev/null
+++ b/Dreamer/plotting.py
@@ -0,0 +1,339 @@
+import argparse
+import collections
+import functools
+import json
+import multiprocessing as mp
+import pathlib
+import re
+import subprocess
+
+import matplotlib.pyplot as plt
+import matplotlib.ticker as ticker
+import numpy as np
+import pandas as pd
+
+
+# import matplotlib
+# matplotlib.rcParams['mathtext.fontset'] = 'stix'
+# matplotlib.rcParams['font.family'] = 'STIXGeneral'
+
+Run = collections.namedtuple("Run", "task method seed xs ys color")
+
+PALETTE = 10 * (
+    "#377eb8",
+    "#4daf4a",
+    "#984ea3",
+    "#e41a1c",
+    "#ff7f00",
+    "#a65628",
+    "#f781bf",
+    "#888888",
+    "#a6cee3",
+    "#b2df8a",
+    "#cab2d6",
+    "#fb9a99",
+    "#fdbf6f",
+)
+
+LEGEND = dict(
+    fontsize="medium",
+    numpoints=1,
+    labelspacing=0,
+    columnspacing=1.2,
+    handlelength=1.5,
+    handletextpad=0.5,
+    ncol=4,
+    loc="lower center",
+)
+
+
+def find_keys(args):
+    filename = next(args.indir[0].glob("**/*.jsonl"))
+    keys = set()
+    for line in filename.read_text().split("\n"):
+        if line:
+            keys |= json.loads(line).keys()
+    print(f"Keys ({len(keys)}):", ", ".join(keys), flush=True)
+
+
+def load_runs(args):
+    toload = []
+    for indir in args.indir:
+        filenames = list(indir.glob("**/*.jsonl"))
+        for filename in filenames:
+            task, method, seed = filename.relative_to(indir).parts[:-1]
+            if not any(p.search(task) for p in args.tasks):
+                continue
+            if not any(p.search(method) for p in args.methods):
+                continue
+            if method not in args.colors:
+                args.colors[method] = args.palette[len(args.colors)]
+            toload.append((filename, indir))
+    print(f"Loading {len(toload)} of {len(filenames)} runs...")
+    jobs = [functools.partial(load_run, f, i, args) for f, i in toload]
+    with mp.Pool(10) as pool:
+        promises = [pool.apply_async(j) for j in jobs]
+        runs = [p.get() for p in promises]
+    runs = [r for r in runs if r is not None]
+    return runs
+
+
+def load_run(filename, indir, args):
+    task, method, seed = filename.relative_to(indir).parts[:-1]
+    num_steps = 1000000
+    try:
+        # Future pandas releases will support JSON files with NaN values.
+        # df = pd.read_json(filename, lines=True)
+        with filename.open() as f:
+            data_arr = []
+            for line in f.readlines():
+                data = json.loads(line)
+                data_arr.append(data)
+                if data["step"] > num_steps:
+                    break
+            df = pd.DataFrame(data_arr)
+    except ValueError as e:
+        print("Invalid", filename.relative_to(indir), e)
+        return
+    try:
+        df = df[[args.xaxis, args.yaxis]].dropna()
+    except KeyError:
+        return
+    xs = df[args.xaxis].to_numpy()
+    ys = df[args.yaxis].to_numpy()
+    color = args.colors[method]
+    return Run(task, method, seed, xs, ys, color)
+
+
+def load_baselines(args):
+    runs = []
+    directory = pathlib.Path(__file__).parent / "baselines"
+    for filename in directory.glob("**/*.json"):
+        for task, methods in json.loads(filename.read_text()).items():
+            for method, score in methods.items():
+                if not any(p.search(method) for p in args.baselines):
+                    continue
+                if method not in args.colors:
+                    args.colors[method] = args.palette[len(args.colors)]
+                color = args.colors[method]
+                runs.append(Run(task, method, None, None, score, color))
+    return runs
+
+
+def stats(runs):
+    baselines = sorted(set(r.method for r in runs if r.xs is None))
+    runs = [r for r in runs if r.xs is not None]
+    tasks = sorted(set(r.task for r in runs))
+    methods = sorted(set(r.method for r in runs))
+    seeds = sorted(set(r.seed for r in runs))
+    print("Loaded", len(runs), "runs.")
+    print(f"Tasks ({len(tasks)}):", ", ".join(tasks))
+    print(f"Methods ({len(methods)}):", ", ".join(methods))
+    print(f"Seeds ({len(seeds)}):", ", ".join(seeds))
+    print(f"Baselines ({len(baselines)}):", ", ".join(baselines))
+
+
+def figure(runs, args):
+    tasks = sorted(set(r.task for r in runs if r.xs is not None))
+    rows = int(np.ceil(len(tasks) / args.cols))
+    figsize = args.size[0] * args.cols, args.size[1] * rows
+    fig, axes = plt.subplots(rows, args.cols, figsize=figsize, squeeze=False)
+    for task, ax in zip(tasks, axes.flatten()):
+        relevant = [r for r in runs if r.task == task]
+        plot(task, ax, relevant, args)
+    if args.xlim:
+        for ax in axes[:-1].flatten():
+            ax.xaxis.get_offset_text().set_visible(False)
+    if args.xlabel:
+        for ax in axes[-1]:
+            ax.set_xlabel(args.xlabel)
+    if args.ylabel:
+        for ax in axes[:, 0]:
+            ax.set_ylabel(args.ylabel)
+    for ax in axes.flatten()[len(tasks):]:
+        ax.axis("off")
+    legend(fig, args.labels, **LEGEND)
+    return fig
+
+
+def plot(task, ax, runs, args):
+    try:
+        env, task = task.split("_", 1)
+        title = env.capitalize() + " " + task.capitalize()
+        # title = task.split('_', 1)[1].replace('_', ' ').title()
+    except ValueError:
+        title = task.title()
+    ax.set_title(title)
+    methods = []
+    methods += sorted(set(r.method for r in runs if r.xs is not None))
+    methods += sorted(set(r.method for r in runs if r.xs is None))
+    xlim = [+np.inf, -np.inf]
+    for index, method in enumerate(methods):
+        relevant = [r for r in runs if r.method == method]
+        if not relevant:
+            continue
+        if any(r.xs is None for r in relevant):
+            baseline(index, method, ax, relevant, args)
+        else:
+            if args.aggregate == "std":
+                xs, ys = curve_std(index, method, ax, relevant, args)
+            elif args.aggregate == "none":
+                xs, ys = curve_individual(index, method, ax, relevant, args)
+            else:
+                raise NotImplementedError(args.aggregate)
+            xlim = [min(xlim[0], xs.min()), max(xlim[1], xs.max())]
+    ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
+    steps = [1, 2, 2.5, 5, 10]
+    ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps))
+    ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps))
+    ax.set_xlim(args.xlim or xlim)
+    if args.xlim:
+        ticks = sorted({*ax.get_xticks(), *args.xlim})
+        ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]]
+        ax.set_xticks(ticks)
+    if args.ylim:
+        ax.set_ylim(args.ylim)
+        ticks = sorted({*ax.get_yticks(), *args.ylim})
+        ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]]
+        ax.set_yticks(ticks)
+
+
+def curve_individual(index, method, ax, runs, args):
+    if args.bins:
+        for i, run in enumerate(runs):
+            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
+            runs[i] = run._replace(xs=xs, ys=ys)
+    zorder = 10000 - 10 * index - 1
+    for run in runs:
+        ax.plot(run.xs, run.ys, label=method, color=run.color, zorder=zorder)
+    return runs[0].xs, runs[0].ys
+
+
+def curve_std(index, method, ax, runs, args):
+    if args.bins:
+        for i, run in enumerate(runs):
+            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
+            runs[i] = run._replace(xs=xs, ys=ys)
+    xs = np.concatenate([r.xs for r in runs])
+    ys = np.concatenate([r.ys for r in runs])
+    order = np.argsort(xs)
+    xs, ys = xs[order], ys[order]
+    color = runs[0].color
+    if args.bins:
+        def reducer(y):
+            return (np.nanmean(np.array(y)), np.nanstd(np.array(y)))
+        xs, ys = binning(xs, ys, args.bins, reducer)
+        ys, std = ys.T
+        kw = dict(color=color, zorder=10000 - 10 * index, alpha=0.1, linewidths=0)
+        ax.fill_between(xs, ys - std, ys + std, **kw)
+    ax.plot(xs, ys, label=method, color=color, zorder=10000 - 10 * index - 1)
+    return xs, ys
+
+
+def baseline(index, method, ax, runs, args):
+    assert len(runs) == 1 and runs[0].xs is None
+    y = np.mean(runs[0].ys)
+    kw = dict(ls="--", color=runs[0].color, zorder=5000 - 10 * index - 1)
+    ax.axhline(y, label=method, **kw)
+
+
+def binning(xs, ys, bins, reducer):
+    binned_xs = np.arange(xs.min(), xs.max() + 1e-10, bins)
+    binned_ys = []
+    for start, stop in zip([-np.inf] + list(binned_xs), binned_xs):
+        left = (xs <= start).sum()
+        right = (xs <= stop).sum()
+        binned_ys.append(reducer(ys[left:right]))
+    binned_ys = np.array(binned_ys)
+    return binned_xs, binned_ys
+
+
+def legend(fig, mapping=None, **kwargs):
+    entries = {}
+    for ax in fig.axes:
+        for handle, label in zip(*ax.get_legend_handles_labels()):
+            if mapping and label in mapping:
+                label = mapping[label]
+            entries[label] = handle
+    leg = fig.legend(entries.values(), entries.keys(), **kwargs)
+    leg.get_frame().set_edgecolor("white")
+    extent = leg.get_window_extent(fig.canvas.get_renderer())
+    extent = extent.transformed(fig.transFigure.inverted())
+    yloc, xloc = kwargs["loc"].split()
+    y0 = dict(lower=extent.y1, center=0, upper=0)[yloc]
+    y1 = dict(lower=1, center=1, upper=extent.y0)[yloc]
+    x0 = dict(left=extent.x1, center=0, right=0)[xloc]
+    x1 = dict(left=1, center=1, right=extent.x0)[xloc]
+    fig.tight_layout(rect=[x0, y0, x1, y1], h_pad=0.5, w_pad=0.5)
+
+
+def save(fig, args):
+    args.outdir.mkdir(parents=True, exist_ok=True)
+    filename = args.outdir / "curves.png"
+    fig.savefig(filename, dpi=130)
+    print("Saved to", filename)
+    filename = args.outdir / "curves.pdf"
+    fig.savefig(filename)
+    try:
+        subprocess.call(["pdfcrop", str(filename), str(filename)])
+    except FileNotFoundError:
+        pass  # Install texlive-extra-utils.
+
+
+def main(args):
+    find_keys(args)
+    runs = load_runs(args)
+    runs += load_baselines(args)
+    stats(runs)
+    if not runs:
+        print("Nothing to plot.")
+        return
+    print("Plotting...")
+    fig = figure(runs, args)
+    save(fig, args)
+
+
+def parse_args():
+    def boolean(x):
+        return bool(["False", "True"].index(x))
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--indir", nargs="+", type=pathlib.Path, required=True)
+    parser.add_argument("--outdir", type=pathlib.Path, required=True)
+    parser.add_argument("--subdir", type=boolean, default=True)
+    parser.add_argument("--xaxis", type=str, required=True)
+    parser.add_argument("--yaxis", type=str, required=True)
+    parser.add_argument("--tasks", nargs="+", default=[r".*"])
+    parser.add_argument("--methods", nargs="+", default=[r".*"])
+    parser.add_argument("--baselines", nargs="+", default=[])
+    parser.add_argument("--bins", type=float, default=0)
+    parser.add_argument("--aggregate", type=str, default="std")
+    parser.add_argument("--size", nargs=2, type=float, default=[2.5, 2.3])
+    parser.add_argument("--cols", type=int, default=4)
+    parser.add_argument("--xlim", nargs=2, type=float, default=None)
+    parser.add_argument("--ylim", nargs=2, type=float, default=None)
+    parser.add_argument("--xlabel", type=str, default=None)
+    parser.add_argument("--ylabel", type=str, default=None)
+    parser.add_argument("--xticks", type=int, default=6)
+    parser.add_argument("--yticks", type=int, default=5)
+    parser.add_argument("--labels", nargs="+", default=None)
+    parser.add_argument("--palette", nargs="+", default=PALETTE)
+    parser.add_argument("--colors", nargs="+", default={})
+    args = parser.parse_args()
+    if args.subdir:
+        args.outdir /= args.indir[0].stem
+    args.indir = [d.expanduser() for d in args.indir]
+    args.outdir = args.outdir.expanduser()
+    if args.labels:
+        assert len(args.labels) % 2 == 0
+        args.labels = {k: v for k, v in zip(args.labels[::2], args.labels[1::2])}
+    if args.colors:
+        assert len(args.colors) % 2 == 0
+        args.colors = {k: v for k, v in zip(args.colors[::2], args.colors[1::2])}
+    args.tasks = [re.compile(p) for p in args.tasks]
+    args.methods = [re.compile(p) for p in args.methods]
+    args.baselines = [re.compile(p) for p in args.baselines]
+    args.palette = 10 * args.palette
+    return args
+
+
+if __name__ == "__main__":
+    main(parse_args())
\ No newline at end of file
diff --git a/Dreamer/run_all_tia.sh b/Dreamer/run_all_tia.sh
new file mode 100755
index 0000000..a33f484
--- /dev/null
+++ b/Dreamer/run_all_tia.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+for seed in 0 1 2
+do
+    python \
+        run.py \
+        --method tia \
+        --configs dmc \
+        --task dmc_cartpole_swingup_driving \
+        --seed $seed
+done