Compare commits: 0ac3131dad...3b61469681 (6 commits)

Commits: 3b61469681, a0cb89820e, c33f3e6dc8, bd4410e9d0, a8b9de1e7e, 297b5bb62d
.gitignore (vendored, new file, 5 lines)
@@ -0,0 +1,5 @@
__pycache__/
*.py[cod]
*.egg-info
./dist
MUJOCO_LOG.TXT
Dreamer/graph_plot.py (new file, 211 lines)
@@ -0,0 +1,211 @@
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker


def binning(xs, ys, bins, reducer):
    # `bins` is a bin width in x-units; each edge collects the y-values whose
    # x lies in the right-closed window (previous_edge, edge].
    binned_xs = np.arange(xs.min(), xs.max() + 1e-10, bins)
    binned_ys = []
    for start, stop in zip([-np.inf] + list(binned_xs), binned_xs):
        left = (xs <= start).sum()
        right = (xs <= stop).sum()
        binned_ys.append(reducer(ys[left:right]))
    binned_ys = np.array(binned_ys)
    return binned_xs, binned_ys


def plot_data(parent_dir, tag_filter="test/return", xaxis='step', value="AverageEpRet", condition="Condition1", smooth=1, bins=30000, xticks=5, yticks=5):
    # List to store all DataFrames
    data = []
    xlim = [+np.inf, -np.inf]

    # Traverse each subfolder in the parent directory
    for subfolder in os.listdir(parent_dir):
        json_dir = os.path.join(parent_dir, subfolder)
        if not os.path.isdir(json_dir):
            continue
        # Read each JSONL file separately
        for json_file in os.listdir(json_dir):
            if not json_file.endswith('.jsonl'):
                continue
            # Read the data from the JSONL file
            df = pd.read_json(os.path.join(json_dir, json_file), lines=True)

            # Skip files that do not contain the filtered tag
            if tag_filter not in df.columns:
                continue

            df = df[['step', tag_filter]].dropna().sort_values(by='step')

            # Apply binning
            xs, ys = binning(df['step'].to_numpy(), df[tag_filter].to_numpy(), bins, np.nanmean)

            # Track the overall x-range (fix: accumulated inside the loop so
            # set_xlim below sees every run, not just the last file read)
            xlim = [min(xlim[0], xs.min()), max(xlim[1], xs.max())]

            # Replace original data with binned data
            df = pd.DataFrame({'step': xs, tag_filter: ys})

            # Append the DataFrame to the list
            data.append(df)

    # Combine all DataFrames
    combined_df = pd.concat(data, ignore_index=True)

    # Plot the combined DataFrame
    sns.set(style="white", font_scale=1.5)
    plot = sns.lineplot(data=combined_df, x=xaxis, y=tag_filter, errorbar='sd')

    ax = plot.axes
    ax.ticklabel_format(axis="x", scilimits=(5, 5))
    steps = [1, 2, 2.5, 5, 10]
    ax.xaxis.set_major_locator(ticker.MaxNLocator(xticks, steps=steps))
    ax.yaxis.set_major_locator(ticker.MaxNLocator(yticks, steps=steps))

    ax.set_xlim(xlim)
    # plt.xlim([0, max])

    # plt.legend(loc='best').set_draggable(True)
    plt.tight_layout(pad=0.5)
    plt.show()


# Call the function
plot_data('/media/vedant/cpsDataStorageWK/Vedant/tia_logs/dmc_cheetah_run_driving/tia/')


# The script stops here; plot_vanilla below is an alternative kept for reference.
exit()


def plot_vanilla(parent_dir, tag_filter="train/return", smoothing=0.99):
    # List to store all EMAs
    emas = []

    # Traverse each subfolder in the parent directory
    for subfolder in os.listdir(parent_dir):
        json_dir = os.path.join(parent_dir, subfolder)
        if not os.path.isdir(json_dir):
            continue
        # Read each JSONL file separately
        for json_file in os.listdir(json_dir):
            if not json_file.endswith('.jsonl'):
                continue
            # Read the data from the JSONL file
            df = pd.read_json(os.path.join(json_dir, json_file), lines=True)

            # Skip files that do not contain the filtered tag
            if tag_filter not in df.columns:
                continue

            df = df[['step', tag_filter]].sort_values(by='step')

            # Exponential moving average for the smoothing value
            df['EMA'] = df[tag_filter].ewm(alpha=smoothing, adjust=False).mean()

            # Append the EMA DataFrame to the emas list
            emas.append(df)

    # Concatenate all EMAs and compute mean and standard deviation per step
    all_emas = pd.concat(emas).groupby('step')['EMA']
    mean_emas = all_emas.mean()
    std_emas = all_emas.std()

    # Plotting begins here
    sns.set_style("whitegrid", {'axes.grid': True, 'axes.edgecolor': 'black'})
    fig = plt.figure()
    plt.clf()
    ax = fig.gca()

    # Plot mean and standard deviation of EMAs
    # (fix: label added so the legend call below has an entry to show)
    plt.plot(mean_emas.index, mean_emas, color='blue', label=tag_filter)
    plt.fill_between(std_emas.index, (mean_emas - std_emas), (mean_emas + std_emas), color='blue', alpha=.1)

    plt.xlabel('Training Episodes $(\\times10^6)$', fontsize=22)
    plt.ylabel('Average return', fontsize=22)
    lgd = plt.legend(frameon=True, fancybox=True, prop={'weight': 'bold', 'size': 14}, loc="best")
    # plt.title('Title', fontsize=14)
    ax = plt.gca()

    plt.setp(ax.get_xticklabels(), fontsize=16)
    plt.setp(ax.get_yticklabels(), fontsize=16)
    sns.despine()
    plt.tight_layout()
    plt.show()


# Call the function
plot_vanilla('/media/vedant/cpsDataStorageWK/Vedant/tia_logs/dmc_cheetah_run_driving/tia/')


"""
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Set the path to the JSONL files
parent_dir = '/media/vedant/cpsDataStorageWK/Vedant/tia_logs/dmc_cheetah_run_driving/tia/'

# Specific tag to filter
tag_filter = "train/return"

# Collect data from all JSON files
data = []

# Smoothing value
smoothing = 0.001

# List to store all EMAs
emas = []

# Traverse each subfolder in the parent directory
for subfolder in os.listdir(parent_dir):
    json_dir = os.path.join(parent_dir, subfolder)
    if not os.path.isdir(json_dir):
        continue
    # Read each JSONL file separately
    for json_file in os.listdir(json_dir):
        if not json_file.endswith('.jsonl'):
            continue
        # Read the data from the JSONL file
        df = pd.read_json(os.path.join(json_dir, json_file), lines=True)

        # Skip files that do not contain the filtered tag
        if tag_filter not in df.columns:
            continue

        df = df[['step', tag_filter]].sort_values(by='step')

        # Exponential moving average for the smoothing value
        df['EMA'] = df[tag_filter].ewm(alpha=smoothing, adjust=False).mean()

        # Append the EMA DataFrame to the emas list
        emas.append(df)

# Concatenate all EMAs and compute mean and standard deviation per step
all_emas = pd.concat(emas).groupby('step')['EMA']
mean_emas = all_emas.mean()
std_emas = all_emas.std()

# Plot mean and standard deviation of EMAs
plt.figure(figsize=(10, 6))
plt.plot(mean_emas.index, mean_emas)
plt.fill_between(std_emas.index, (mean_emas - std_emas), (mean_emas + std_emas), color='b', alpha=.1)
plt.legend()
plt.show()
"""
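For context, `binning` above treats its `bins` argument as a bin *width* along the x-axis (30000 steps by default in `plot_data`), not a bin count. A minimal sketch with hypothetical values, using `binning` as defined in the file above:

```python
import numpy as np

# Hypothetical data: six logged points at steps 0..5.
xs = np.array([0, 1, 2, 3, 4, 5])
ys = np.array([0., 10., 20., 30., 40., 50.])

# With bins=2, the edges are [0, 2, 4]; each edge collects the points with
# x in (previous_edge, edge], so the windows are {0}, {1, 2}, and {3, 4}.
# The trailing point at x=5 falls past the last edge and is dropped.
binned_xs, binned_ys = binning(xs, ys, 2, np.nanmean)
print(binned_xs)  # [0. 2. 4.]
print(binned_ys)  # [ 0. 15. 35.]
```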
Dreamer/plotting.py (new file, 339 lines)
@@ -0,0 +1,339 @@
import argparse
import collections
import functools
import json
import multiprocessing as mp
import pathlib
import re
import subprocess

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd


# import matplotlib
# matplotlib.rcParams['mathtext.fontset'] = 'stix'
# matplotlib.rcParams['font.family'] = 'STIXGeneral'

Run = collections.namedtuple("Run", "task method seed xs ys color")

PALETTE = 10 * (
    "#377eb8",
    "#4daf4a",
    "#984ea3",
    "#e41a1c",
    "#ff7f00",
    "#a65628",
    "#f781bf",
    "#888888",
    "#a6cee3",
    "#b2df8a",
    "#cab2d6",
    "#fb9a99",
    "#fdbf6f",
)

LEGEND = dict(
    fontsize="medium",
    numpoints=1,
    labelspacing=0,
    columnspacing=1.2,
    handlelength=1.5,
    handletextpad=0.5,
    ncol=4,
    loc="lower center",
)


def find_keys(args):
    # Inspect one JSONL file to report which metric keys are available.
    filename = next(args.indir[0].glob("**/*.jsonl"))
    keys = set()
    for line in filename.read_text().split("\n"):
        if line:
            keys |= json.loads(line).keys()
    print(f"Keys ({len(keys)}):", ", ".join(keys), flush=True)


def load_runs(args):
    toload = []
    for indir in args.indir:
        filenames = list(indir.glob("**/*.jsonl"))
        for filename in filenames:
            task, method, seed = filename.relative_to(indir).parts[:-1]
            if not any(p.search(task) for p in args.tasks):
                continue
            if not any(p.search(method) for p in args.methods):
                continue
            if method not in args.colors:
                args.colors[method] = args.palette[len(args.colors)]
            toload.append((filename, indir))
    print(f"Loading {len(toload)} of {len(filenames)} runs...")
    jobs = [functools.partial(load_run, f, i, args) for f, i in toload]
    with mp.Pool(10) as pool:
        promises = [pool.apply_async(j) for j in jobs]
        runs = [p.get() for p in promises]
    runs = [r for r in runs if r is not None]
    return runs


def load_run(filename, indir, args):
    task, method, seed = filename.relative_to(indir).parts[:-1]
    num_steps = 1000000
    try:
        # Future pandas releases will support JSON files with NaN values.
        # df = pd.read_json(filename, lines=True)
        with filename.open() as f:
            data_arr = []
            for l in f.readlines():
                data = json.loads(l)
                data_arr.append(data)
                if data["step"] > num_steps:
                    break
            df = pd.DataFrame(data_arr)
    except ValueError as e:
        print("Invalid", filename.relative_to(indir), e)
        return
    try:
        df = df[[args.xaxis, args.yaxis]].dropna()
    except KeyError:
        return
    xs = df[args.xaxis].to_numpy()
    ys = df[args.yaxis].to_numpy()
    color = args.colors[method]
    return Run(task, method, seed, xs, ys, color)


def load_baselines(args):
    runs = []
    directory = pathlib.Path(__file__).parent / "baselines"
    for filename in directory.glob("**/*.json"):
        for task, methods in json.loads(filename.read_text()).items():
            for method, score in methods.items():
                if not any(p.search(method) for p in args.baselines):
                    continue
                if method not in args.colors:
                    args.colors[method] = args.palette[len(args.colors)]
                color = args.colors[method]
                runs.append(Run(task, method, None, None, score, color))
    return runs


def stats(runs):
    baselines = sorted(set(r.method for r in runs if r.xs is None))
    runs = [r for r in runs if r.xs is not None]
    tasks = sorted(set(r.task for r in runs))
    methods = sorted(set(r.method for r in runs))
    seeds = sorted(set(r.seed for r in runs))
    print("Loaded", len(runs), "runs.")
    print(f"Tasks ({len(tasks)}):", ", ".join(tasks))
    print(f"Methods ({len(methods)}):", ", ".join(methods))
    print(f"Seeds ({len(seeds)}):", ", ".join(seeds))
    print(f"Baselines ({len(baselines)}):", ", ".join(baselines))


def figure(runs, args):
    tasks = sorted(set(r.task for r in runs if r.xs is not None))
    rows = int(np.ceil(len(tasks) / args.cols))
    figsize = args.size[0] * args.cols, args.size[1] * rows
    fig, axes = plt.subplots(rows, args.cols, figsize=figsize)
    for task, ax in zip(tasks, axes.flatten()):
        relevant = [r for r in runs if r.task == task]
        plot(task, ax, relevant, args)
    if args.xlim:
        for ax in axes[:-1].flatten():
            ax.xaxis.get_offset_text().set_visible(False)
    if args.xlabel:
        for ax in axes[-1]:
            ax.set_xlabel(args.xlabel)
    if args.ylabel:
        for ax in axes[:, 0]:
            ax.set_ylabel(args.ylabel)
    # Fix: flatten so leftover subplots in the last row are actually hidden.
    for ax in axes.flatten()[len(tasks):]:
        ax.axis("off")
    legend(fig, args.labels, **LEGEND)
    return fig


def plot(task, ax, runs, args):
    try:
        env, task = task.split("_", 1)
        title = env.capitalize() + " " + task.capitalize()
        # title = task.split('_', 1)[1].replace('_', ' ').title()
    except ValueError:  # fix: unpacking a name without "_" raises ValueError
        title = task.title()
    ax.set_title(title)
    methods = []
    methods += sorted(set(r.method for r in runs if r.xs is not None))
    methods += sorted(set(r.method for r in runs if r.xs is None))
    xlim = [+np.inf, -np.inf]
    for index, method in enumerate(methods):
        relevant = [r for r in runs if r.method == method]
        if not relevant:
            continue
        if any(r.xs is None for r in relevant):
            baseline(index, method, ax, relevant, args)
        else:
            if args.aggregate == "std":
                xs, ys = curve_std(index, method, ax, relevant, args)
            elif args.aggregate == "none":
                xs, ys = curve_individual(index, method, ax, relevant, args)
            else:
                raise NotImplementedError(args.aggregate)
            xlim = [min(xlim[0], xs.min()), max(xlim[1], xs.max())]
    ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    steps = [1, 2, 2.5, 5, 10]
    ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps))
    ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps))
    ax.set_xlim(args.xlim or xlim)
    if args.xlim:
        ticks = sorted({*ax.get_xticks(), *args.xlim})
        ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]]
        ax.set_xticks(ticks)
    if args.ylim:
        ax.set_ylim(args.ylim)
        ticks = sorted({*ax.get_yticks(), *args.ylim})
        ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]]
        ax.set_yticks(ticks)


def curve_individual(index, method, ax, runs, args):
    if args.bins:
        for index, run in enumerate(runs):
            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
            runs[index] = run._replace(xs=xs, ys=ys)
    zorder = 10000 - 10 * index - 1
    for run in runs:
        ax.plot(run.xs, run.ys, label=method, color=run.color, zorder=zorder)
    return runs[0].xs, runs[0].ys


def curve_std(index, method, ax, runs, args):
    if args.bins:
        for index, run in enumerate(runs):
            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
            runs[index] = run._replace(xs=xs, ys=ys)
    xs = np.concatenate([r.xs for r in runs])
    ys = np.concatenate([r.ys for r in runs])
    order = np.argsort(xs)
    xs, ys = xs[order], ys[order]
    color = runs[0].color
    if args.bins:
        def reducer(y):
            return (np.nanmean(np.array(y)), np.nanstd(np.array(y)))
        xs, ys = binning(xs, ys, args.bins, reducer)
        ys, std = ys.T
        # The shaded std band is only available when binning is enabled,
        # since std is computed per bin.
        kw = dict(color=color, zorder=10000 - 10 * index, alpha=0.1, linewidths=0)
        ax.fill_between(xs, ys - std, ys + std, **kw)
    ax.plot(xs, ys, label=method, color=color, zorder=10000 - 10 * index - 1)
    return xs, ys


def baseline(index, method, ax, runs, args):
    assert len(runs) == 1 and runs[0].xs is None
    y = np.mean(runs[0].ys)
    kw = dict(ls="--", color=runs[0].color, zorder=5000 - 10 * index - 1)
    ax.axhline(y, label=method, **kw)


def binning(xs, ys, bins, reducer):
    # `bins` is a bin width in x-units; each edge collects the y-values whose
    # x lies in the right-closed window (previous_edge, edge].
    binned_xs = np.arange(xs.min(), xs.max() + 1e-10, bins)
    binned_ys = []
    for start, stop in zip([-np.inf] + list(binned_xs), binned_xs):
        left = (xs <= start).sum()
        right = (xs <= stop).sum()
        binned_ys.append(reducer(ys[left:right]))
    binned_ys = np.array(binned_ys)
    return binned_xs, binned_ys


def legend(fig, mapping=None, **kwargs):
    entries = {}
    for ax in fig.axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            if mapping and label in mapping:
                label = mapping[label]
            entries[label] = handle
    leg = fig.legend(entries.values(), entries.keys(), **kwargs)
    leg.get_frame().set_edgecolor("white")
    extent = leg.get_window_extent(fig.canvas.get_renderer())
    extent = extent.transformed(fig.transFigure.inverted())
    yloc, xloc = kwargs["loc"].split()
    y0 = dict(lower=extent.y1, center=0, upper=0)[yloc]
    y1 = dict(lower=1, center=1, upper=extent.y0)[yloc]
    x0 = dict(left=extent.x1, center=0, right=0)[xloc]
    x1 = dict(left=1, center=1, right=extent.x0)[xloc]
    fig.tight_layout(rect=[x0, y0, x1, y1], h_pad=0.5, w_pad=0.5)


def save(fig, args):
    args.outdir.mkdir(parents=True, exist_ok=True)
    filename = args.outdir / "curves.png"
    fig.savefig(filename, dpi=130)
    print("Saved to", filename)
    filename = args.outdir / "curves.pdf"
    fig.savefig(filename)
    try:
        subprocess.call(["pdfcrop", str(filename), str(filename)])
    except FileNotFoundError:
        pass  # pdfcrop not found; install texlive-extra-utils.


def main(args):
    find_keys(args)
    runs = load_runs(args) + load_baselines(args)
    stats(runs)
    if not runs:
        print("Nothing to plot.")
        return
    print("Plotting...")
    fig = figure(runs, args)
    save(fig, args)


def parse_args():
    def boolean(x):
        return bool(["False", "True"].index(x))

    parser = argparse.ArgumentParser()
    parser.add_argument("--indir", nargs="+", type=pathlib.Path, required=True)
    parser.add_argument("--outdir", type=pathlib.Path, required=True)
    parser.add_argument("--subdir", type=boolean, default=True)
    parser.add_argument("--xaxis", type=str, required=True)
    parser.add_argument("--yaxis", type=str, required=True)
    parser.add_argument("--tasks", nargs="+", default=[r".*"])
    parser.add_argument("--methods", nargs="+", default=[r".*"])
    parser.add_argument("--baselines", nargs="+", default=[])
    parser.add_argument("--bins", type=float, default=0)
    parser.add_argument("--aggregate", type=str, default="std")
    parser.add_argument("--size", nargs=2, type=float, default=[2.5, 2.3])
    parser.add_argument("--cols", type=int, default=4)
    parser.add_argument("--xlim", nargs=2, type=float, default=None)
    parser.add_argument("--ylim", nargs=2, type=float, default=None)
    parser.add_argument("--xlabel", type=str, default=None)
    parser.add_argument("--ylabel", type=str, default=None)
    parser.add_argument("--xticks", type=int, default=6)
    parser.add_argument("--yticks", type=int, default=5)
    parser.add_argument("--labels", nargs="+", default=None)
    parser.add_argument("--palette", nargs="+", default=PALETTE)
    parser.add_argument("--colors", nargs="+", default={})
    args = parser.parse_args()
    if args.subdir:
        args.outdir /= args.indir[0].stem
    args.indir = [d.expanduser() for d in args.indir]
    args.outdir = args.outdir.expanduser()
    if args.labels:
        assert len(args.labels) % 2 == 0
        # Fix: pair consecutive (name, label) arguments instead of overlapping pairs.
        args.labels = {k: v for k, v in zip(args.labels[::2], args.labels[1::2])}
    if args.colors:
        assert len(args.colors) % 2 == 0
        # Fix: pair consecutive (method, color) arguments instead of overlapping pairs.
        args.colors = {k: v for k, v in zip(args.colors[::2], args.colors[1::2])}
    args.tasks = [re.compile(p) for p in args.tasks]
    args.methods = [re.compile(p) for p in args.methods]
    args.baselines = [re.compile(p) for p in args.baselines]
    args.palette = 10 * args.palette
    return args


if __name__ == "__main__":
    main(parse_args())
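For reference, a hypothetical invocation of the script above, assuming logs are laid out as `indir/{task}/{method}/{seed}/*.jsonl` (the layout `load_runs` expects); all flags come from `parse_args`, and the method names and paths are illustrative:

```sh
python Dreamer/plotting.py \
    --indir ~/logdir \
    --outdir ~/plots \
    --xaxis step \
    --yaxis test/return \
    --bins 30000 \
    --labels dreamer Dreamer tia TIA
```

This writes `curves.png` and `curves.pdf` under the output directory, which `--subdir` extends with the first input directory's stem by default.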
Dreamer/run_all_tia.sh (new executable file, 11 lines)
@@ -0,0 +1,11 @@
#!/bin/bash

for seed in 0 1 2
do
    python \
        run.py \
        --method tia \
        --configs dmc \
        --task dmc_cartpole_swingup_driving \
        --seed $seed
done
README.md (new file, 136 lines)
@@ -0,0 +1,136 @@
# Learning Task Informed Abstractions (TIA)

<sub>Left to right: Raw Observation, Dreamer, Joint of TIA, Task Stream of TIA, Distractor Stream of TIA</sub>

![](imgs/gt.gif) ![](imgs/pred.gif) ![](imgs/joint.gif) ![](imgs/main.gif) ![](imgs/disen.gif)

This codebase contains a minimal modification of [Dreamer](https://danijar.com/project/dreamer/)/[DreamerV2](https://danijar.com/project/dreamerv2/) to learn disentangled world models, presented in:

**Learning Task Informed Abstractions**

Xiang Fu*, Ge Yang*, Pulkit Agrawal, Tommi Jaakkola

ICML 2021 [[website]](https://xiangfu.co/tia) [[paper]](https://arxiv.org/abs/2106.15612)

The directory [Dreamer](./Dreamer) contains code for running DMC experiments. The directory [DreamerV2](./DreamerV2) contains code for running Atari experiments. This implementation is tested with Python 3.6, TensorFlow 2.3.1, and CUDA 10.1. The training/evaluation metrics used for producing the figures in the paper can be downloaded from [this Google Drive link](https://drive.google.com/file/d/1wvSp9Q7r2Ah5xRE_x3nJy-uwLkjF2RgX/view?usp=sharing).

## Getting started

Install the dependencies:

```sh
pip3 install --user tensorflow-gpu==2.3.1
pip3 install --user tensorflow_probability==0.11.0
pip3 install --user gym
pip3 install --user pandas
pip3 install --user matplotlib
pip3 install --user ruamel.yaml
pip3 install --user scikit-image
pip3 install --user git+git://github.com/deepmind/dm_control.git
pip3 install --user 'gym[atari]'
```

You will need an active MuJoCo license to run the DMC experiments.

## Running DMC experiments with distracting background

Code for running DMC experiments is under the directory [Dreamer](./Dreamer).

To run DMC experiments with distracting video backgrounds, you can download a small set of 16 videos (videos whose names start with "A" in the Kinetics 400 dataset's `driving_car` class) from [this Google Drive link](https://drive.google.com/file/d/1f-ER2XnhpvQeGjlJaoGRiLR0oEjn6Le_/view?usp=sharing), which is used for producing Figure 9(a) in the paper's appendix.

To replicate the setup of [DBC](https://github.com/facebookresearch/deep_bisim4control) and use more background videos, first download the Kinetics 400 dataset and grab the `driving_car` label from the train dataset; the repo [https://github.com/Showmax/kinetics-downloader](https://github.com/Showmax/kinetics-downloader) can be used to download it.

Train the agent:

```sh
python run.py --method dreamer --configs dmc --task dmc_cheetah_run_driving --logdir ~/logdir --video_dir VIDPATH
```

`VIDPATH` should contain `*.mp4` video files. (If you used the above repo to download the Kinetics videos, set `VIDPATH` to `PATH_TO_REPO/kinetics-downloader/dataset/train/driving_car`.)

Choose a method from:

```
[dreamer, tia, inverse]
```

corresponding to the original Dreamer, TIA, and a representation learned with an inverse model as described in Section 4.2 of the paper.
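For example, the same training run with TIA instead of Dreamer (paths as in the command above):

```sh
python run.py --method tia --configs dmc --task dmc_cheetah_run_driving --logdir ~/logdir --video_dir VIDPATH
```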
Choose an environment + distraction combination (e.g. `dmc_cheetah_run_driving`):

```
dmc_{domain}_{task}_{distraction}
```

where `{domain}` (e.g., cheetah, walker, hopper) and `{task}` (e.g., run, walk, stand) are from the DeepMind Control Suite, and `{distraction}` can be chosen from:

```
[none, noise, driving]
```

where each option uses a different background:

```
none: default (no) background
noise: white noise background
driving: natural videos from the "driving car" class as background
```
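For instance, these are all valid combinations under the naming scheme above (illustrative choices of domain and task):

```
dmc_cheetah_run_driving   # cheetah run against natural driving videos
dmc_walker_walk_noise     # walker walk against white noise
dmc_hopper_stand_none     # hopper stand with the default background
```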
## Running Atari experiments

Code for running Atari experiments is under the directory [DreamerV2](./DreamerV2).

Train the agent on the game Demon Attack:

```sh
python dreamer.py --logdir ~/logdir/atari_demon_attack/TIA/1 \
    --configs defaults atari --task atari_demon_attack
```

## Monitoring results

Both the DMC and Atari experiments log with TensorBoard by default. The decomposition of the two streams of TIA is visualized as `.gif` animations. Access TensorBoard with the command:

```sh
tensorboard --logdir LOGDIR
```
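Alongside the TensorBoard events, metrics are written as JSON line files (one JSON object per logging step); these are what `Dreamer/graph_plot.py` and `Dreamer/plotting.py` consume. A hypothetical line, with keys as read by those scripts and illustrative values:

```
{"step": 100000, "train/return": 712.3, "test/return": 689.0}
```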
## Citation

If you find this code useful, please consider citing our paper:

```
@InProceedings{fu2021learning,
  title     = {Learning Task Informed Abstractions},
  author    = {Fu, Xiang and Yang, Ge and Agrawal, Pulkit and Jaakkola, Tommi},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning},
  pages     = {3480--3491},
  year      = {2021},
  editor    = {Meila, Marina and Zhang, Tong},
  volume    = {139},
  series    = {Proceedings of Machine Learning Research},
  month     = {18--24 Jul},
  publisher = {PMLR},
  pdf       = {http://proceedings.mlr.press/v139/fu21b/fu21b.pdf},
  url       = {http://proceedings.mlr.press/v139/fu21b.html}
}
```

## Reference

We modify [Dreamer](https://github.com/danijar/dreamer) for the DMC environments and [DreamerV2](https://github.com/danijar/dreamerv2) for the Atari games. Thanks to Danijar for releasing his very clean implementations! Utilities such as

- logging with TensorBoard/JSON line files
- debugging with the `debug` flag
- mixed-precision training

are the same as in the respective original implementations.
imgs/disen.gif (new binary file, 872 KiB; not shown)
imgs/gt.gif (new binary file, 931 KiB; not shown)
imgs/joint.gif (new binary file, 925 KiB; not shown)
imgs/main.gif (new binary file, 899 KiB; not shown)
imgs/pred.gif (new binary file, 970 KiB; not shown)