Compare commits
No commits in common. "3b614696817a11ba8cbb637a61bf6c84129a029c" and "0ac3131dad7214c33db645b70be15b3876f83bb1" have entirely different histories.
3b61469681...0ac3131dad
.gitignore (vendored, 5 lines deleted)
@@ -1,5 +0,0 @@
__pycache__/
*.py[cod]
*.egg-info
./dist
MUJOCO_LOG.TXT
@@ -1,211 +0,0 @@
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker


def binning(xs, ys, bins, reducer):
    binned_xs = np.arange(xs.min(), xs.max() + 1e-10, bins)
    binned_ys = []
    for start, stop in zip([-np.inf] + list(binned_xs), binned_xs):
        left = (xs <= start).sum()
        right = (xs <= stop).sum()
        binned_ys.append(reducer(ys[left:right]))
    binned_ys = np.array(binned_ys)
    return binned_xs, binned_ys
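Here `bins` is the width of each x-axis bin and `reducer` aggregates the y-values falling into each bin. A small worked sketch of the behavior, with illustrative values (note that points past the last bin edge, x=50 below, are dropped):

```python
import numpy as np

xs = np.array([0, 10, 20, 30, 40, 50])
ys = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# Bin width 20: edges at 0, 20, 40. Each bin averages the y-values
# whose x lies in (previous_edge, edge]; x=50 falls past the last
# edge and is discarded.
binned_xs, binned_ys = binning(xs, ys, 20, np.nanmean)
print(binned_xs)  # [ 0. 20. 40.]
print(binned_ys)  # [1.  2.5 4.5]
```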

def plot_data(parent_dir, tag_filter="test/return", xaxis='step', value="AverageEpRet", condition="Condition1", smooth=1, bins=30000, xticks=5, yticks=5):
    # List to store all DataFrames
    data = []

    # Traverse each subfolder in the parent directory
    for subfolder in os.listdir(parent_dir):
        json_dir = os.path.join(parent_dir, subfolder)
        if not os.path.isdir(json_dir):
            continue
        # Read each JSON lines file separately
        for json_file in os.listdir(json_dir):
            if not json_file.endswith('.jsonl'):
                continue
            # Read the data from the JSON file
            df = pd.read_json(os.path.join(json_dir, json_file), lines=True)

            # Check if tag_filter exists in the DataFrame
            if tag_filter not in df.columns:
                continue

            df = df[['step', tag_filter]].dropna().sort_values(by='step')

            # Apply binning
            xs, ys = binning(df['step'].to_numpy(), df[tag_filter].to_numpy(), bins, np.nanmean)

            # Replace the original data with the binned data
            df = pd.DataFrame({'step': xs, tag_filter: ys})

            # Append the DataFrame to the list
            data.append(df)

    # Combine all DataFrames
    combined_df = pd.concat(data, ignore_index=True)

    # Plot the combined DataFrame
    sns.set(style="white", font_scale=1.5)
    plot = sns.lineplot(data=combined_df, x=xaxis, y=tag_filter, errorbar='sd')

    ax = plot.axes
    ax.ticklabel_format(axis="x", scilimits=(5, 5))
    steps = [1, 2, 2.5, 5, 10]
    ax.xaxis.set_major_locator(ticker.MaxNLocator(xticks, steps=steps))
    ax.yaxis.set_major_locator(ticker.MaxNLocator(yticks, steps=steps))

    # Fit the x-axis to the full range of steps across all runs
    xlim = [combined_df[xaxis].min(), combined_df[xaxis].max()]
    ax.set_xlim(xlim)

    plt.tight_layout(pad=0.5)
    plt.show()


# Call the function
plot_data('/media/vedant/cpsDataStorageWK/Vedant/tia_logs/dmc_cheetah_run_driving/tia/')


exit()


def plot_vanilla(parent_dir, tag_filter="train/return", smoothing=0.99):
    # List to store all EMAs
    emas = []

    # Traverse each subfolder in the parent directory
    for subfolder in os.listdir(parent_dir):
        json_dir = os.path.join(parent_dir, subfolder)
        if not os.path.isdir(json_dir):
            continue
        # Read each JSON lines file separately
        for json_file in os.listdir(json_dir):
            if not json_file.endswith('.jsonl'):
                continue
            # Read the data from the JSON file
            df = pd.read_json(os.path.join(json_dir, json_file), lines=True)

            # Check if tag_filter exists in the DataFrame
            if tag_filter not in df.columns:
                continue

            df = df[['step', tag_filter]].sort_values(by='step')

            # Calculate the exponential moving average for the smoothing value
            df['EMA'] = df[tag_filter].ewm(alpha=smoothing, adjust=False).mean()

            # Append the EMA DataFrame to the emas list
            emas.append(df)

    # Concatenate all EMAs into a single DataFrame and calculate the
    # mean and standard deviation per step
    all_emas = pd.concat(emas).groupby('step')['EMA']
    mean_emas = all_emas.mean()
    std_emas = all_emas.std()

    # Plotting begins here
    sns.set_style("whitegrid", {'axes.grid': True, 'axes.edgecolor': 'black'})
    fig = plt.figure()
    plt.clf()
    ax = fig.gca()

    # Plot the mean and standard deviation of the EMAs
    plt.plot(mean_emas.index, mean_emas, color='blue')
    plt.fill_between(std_emas.index, mean_emas - std_emas, mean_emas + std_emas, color='blue', alpha=.1)

    plt.xlabel('Training Episodes $(\\times10^6)$', fontsize=22)
    plt.ylabel('Average return', fontsize=22)
    lgd = plt.legend(frameon=True, fancybox=True, prop={'weight': 'bold', 'size': 14}, loc="best")
    ax = plt.gca()

    plt.setp(ax.get_xticklabels(), fontsize=16)
    plt.setp(ax.get_yticklabels(), fontsize=16)
    sns.despine()
    plt.tight_layout()
    plt.show()


# Call the function
plot_vanilla('/media/vedant/cpsDataStorageWK/Vedant/tia_logs/dmc_cheetah_run_driving/tia/')
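One caveat on the `ewm` call above: with `adjust=False`, `alpha` is the weight on the *current* sample, so the default `smoothing=0.99` barely smooths at all, while a small value such as `0.001` smooths heavily. A quick sketch of the recurrence pandas computes (illustrative values):

```python
import pandas as pd

s = pd.Series([0.0, 10.0, 10.0, 10.0])
# With adjust=False: y[0] = x[0] and y[t] = alpha * x[t] + (1 - alpha) * y[t-1].
print(s.ewm(alpha=0.5, adjust=False).mean().tolist())
# [0.0, 5.0, 7.5, 8.75]
```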

"""
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Set the path to the JSON files
parent_dir = '/media/vedant/cpsDataStorageWK/Vedant/tia_logs/dmc_cheetah_run_driving/tia/'

# Specific tag to filter
tag_filter = "train/return"

# Collect data from all JSON files
data = []

# Smoothing value
smoothing = 0.001

# List to store all EMAs
emas = []

# Traverse each subfolder in the parent directory
for subfolder in os.listdir(parent_dir):
    json_dir = os.path.join(parent_dir, subfolder)
    if not os.path.isdir(json_dir):
        continue
    # Read each JSON lines file separately
    for json_file in os.listdir(json_dir):
        if not json_file.endswith('.jsonl'):
            continue
        # Read the data from the JSON file
        df = pd.read_json(os.path.join(json_dir, json_file), lines=True)

        # Check if tag_filter exists in the DataFrame
        if tag_filter not in df.columns:
            continue

        df = df[['step', tag_filter]].sort_values(by='step')

        # Calculate the exponential moving average for the smoothing value
        df['EMA'] = df[tag_filter].ewm(alpha=smoothing, adjust=False).mean()

        # Append the EMA DataFrame to the emas list
        emas.append(df)

# Concatenate all EMAs into a single DataFrame and calculate mean and standard deviation
all_emas = pd.concat(emas).groupby('step')['EMA']
mean_emas = all_emas.mean()
std_emas = all_emas.std()

# Plot mean and standard deviation of EMAs
plt.figure(figsize=(10, 6))
plt.plot(mean_emas.index, mean_emas)
plt.fill_between(std_emas.index, mean_emas - std_emas, mean_emas + std_emas, color='b', alpha=.1)
plt.legend()
plt.show()
"""
@@ -1,339 +0,0 @@
import argparse
import collections
import functools
import json
import multiprocessing as mp
import pathlib
import re
import subprocess

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd


# import matplotlib
# matplotlib.rcParams['mathtext.fontset'] = 'stix'
# matplotlib.rcParams['font.family'] = 'STIXGeneral'

Run = collections.namedtuple("Run", "task method seed xs ys color")

PALETTE = 10 * (
    "#377eb8",
    "#4daf4a",
    "#984ea3",
    "#e41a1c",
    "#ff7f00",
    "#a65628",
    "#f781bf",
    "#888888",
    "#a6cee3",
    "#b2df8a",
    "#cab2d6",
    "#fb9a99",
    "#fdbf6f",
)

LEGEND = dict(
    fontsize="medium",
    numpoints=1,
    labelspacing=0,
    columnspacing=1.2,
    handlelength=1.5,
    handletextpad=0.5,
    ncol=4,
    loc="lower center",
)


def find_keys(args):
    filename = next(args.indir[0].glob("**/*.jsonl"))
    keys = set()
    for line in filename.read_text().split("\n"):
        if line:
            keys |= json.loads(line).keys()
    print(f"Keys ({len(keys)}):", ", ".join(keys), flush=True)


def load_runs(args):
    toload = []
    for indir in args.indir:
        filenames = list(indir.glob("**/*.jsonl"))
        for filename in filenames:
            task, method, seed = filename.relative_to(indir).parts[:-1]
            if not any(p.search(task) for p in args.tasks):
                continue
            if not any(p.search(method) for p in args.methods):
                continue
            if method not in args.colors:
                args.colors[method] = args.palette[len(args.colors)]
            toload.append((filename, indir))
    print(f"Loading {len(toload)} of {len(filenames)} runs...")
    jobs = [functools.partial(load_run, f, i, args) for f, i in toload]
    with mp.Pool(10) as pool:
        promises = [pool.apply_async(j) for j in jobs]
        runs = [p.get() for p in promises]
    runs = [r for r in runs if r is not None]
    return runs


def load_run(filename, indir, args):
    task, method, seed = filename.relative_to(indir).parts[:-1]
    num_steps = 1000000
    try:
        # Future pandas releases will support JSON files with NaN values.
        # df = pd.read_json(filename, lines=True)
        with filename.open() as f:
            data_arr = []
            for l in f.readlines():
                data = json.loads(l)
                data_arr.append(data)
                if data["step"] > num_steps:
                    break
            df = pd.DataFrame(data_arr)
    except ValueError as e:
        print("Invalid", filename.relative_to(indir), e)
        return
    try:
        df = df[[args.xaxis, args.yaxis]].dropna()
    except KeyError:
        return
    xs = df[args.xaxis].to_numpy()
    ys = df[args.yaxis].to_numpy()
    color = args.colors[method]
    return Run(task, method, seed, xs, ys, color)


def load_baselines(args):
    runs = []
    directory = pathlib.Path(__file__).parent / "baselines"
    for filename in directory.glob("**/*.json"):
        for task, methods in json.loads(filename.read_text()).items():
            for method, score in methods.items():
                if not any(p.search(method) for p in args.baselines):
                    continue
                if method not in args.colors:
                    args.colors[method] = args.palette[len(args.colors)]
                color = args.colors[method]
                runs.append(Run(task, method, None, None, score, color))
    return runs


def stats(runs):
    baselines = sorted(set(r.method for r in runs if r.xs is None))
    runs = [r for r in runs if r.xs is not None]
    tasks = sorted(set(r.task for r in runs))
    methods = sorted(set(r.method for r in runs))
    seeds = sorted(set(r.seed for r in runs))
    print("Loaded", len(runs), "runs.")
    print(f"Tasks ({len(tasks)}):", ", ".join(tasks))
    print(f"Methods ({len(methods)}):", ", ".join(methods))
    print(f"Seeds ({len(seeds)}):", ", ".join(seeds))
    print(f"Baselines ({len(baselines)}):", ", ".join(baselines))


def figure(runs, args):
    tasks = sorted(set(r.task for r in runs if r.xs is not None))
    rows = int(np.ceil(len(tasks) / args.cols))
    figsize = args.size[0] * args.cols, args.size[1] * rows
    fig, axes = plt.subplots(rows, args.cols, figsize=figsize)
    for task, ax in zip(tasks, axes.flatten()):
        relevant = [r for r in runs if r.task == task]
        plot(task, ax, relevant, args)
    if args.xlim:
        for ax in axes[:-1].flatten():
            ax.xaxis.get_offset_text().set_visible(False)
    if args.xlabel:
        for ax in axes[-1]:
            ax.set_xlabel(args.xlabel)
    if args.ylabel:
        for ax in axes[:, 0]:
            ax.set_ylabel(args.ylabel)
    for ax in axes.flatten()[len(tasks):]:
        ax.axis("off")
    legend(fig, args.labels, **LEGEND)
    return fig


def plot(task, ax, runs, args):
    try:
        env, task = task.split("_", 1)
        title = env.capitalize() + " " + task.capitalize()
        # title = task.split('_', 1)[1].replace('_', ' ').title()
    except ValueError:
        title = task.title()
    ax.set_title(title)
    methods = []
    methods += sorted(set(r.method for r in runs if r.xs is not None))
    methods += sorted(set(r.method for r in runs if r.xs is None))
    xlim = [+np.inf, -np.inf]
    for index, method in enumerate(methods):
        relevant = [r for r in runs if r.method == method]
        if not relevant:
            continue
        if any(r.xs is None for r in relevant):
            baseline(index, method, ax, relevant, args)
        else:
            if args.aggregate == "std":
                xs, ys = curve_std(index, method, ax, relevant, args)
            elif args.aggregate == "none":
                xs, ys = curve_individual(index, method, ax, relevant, args)
            else:
                raise NotImplementedError(args.aggregate)
            xlim = [min(xlim[0], xs.min()), max(xlim[1], xs.max())]
    ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    steps = [1, 2, 2.5, 5, 10]
    ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps))
    ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps))
    ax.set_xlim(args.xlim or xlim)
    if args.xlim:
        ticks = sorted({*ax.get_xticks(), *args.xlim})
        ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]]
        ax.set_xticks(ticks)
    if args.ylim:
        ax.set_ylim(args.ylim)
        ticks = sorted({*ax.get_yticks(), *args.ylim})
        ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]]
        ax.set_yticks(ticks)


def curve_individual(index, method, ax, runs, args):
    if args.bins:
        for index, run in enumerate(runs):
            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
            runs[index] = run._replace(xs=xs, ys=ys)
    zorder = 10000 - 10 * index - 1
    for run in runs:
        ax.plot(run.xs, run.ys, label=method, color=run.color, zorder=zorder)
    return runs[0].xs, runs[0].ys


def curve_std(index, method, ax, runs, args):
    if args.bins:
        for index, run in enumerate(runs):
            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
            runs[index] = run._replace(xs=xs, ys=ys)
    xs = np.concatenate([r.xs for r in runs])
    ys = np.concatenate([r.ys for r in runs])
    order = np.argsort(xs)
    xs, ys = xs[order], ys[order]
    color = runs[0].color
    if args.bins:
        def reducer(y):
            return (np.nanmean(np.array(y)), np.nanstd(np.array(y)))
        xs, ys = binning(xs, ys, args.bins, reducer)
        ys, std = ys.T
        kw = dict(color=color, zorder=10000 - 10 * index, alpha=0.1, linewidths=0)
        ax.fill_between(xs, ys - std, ys + std, **kw)
    ax.plot(xs, ys, label=method, color=color, zorder=10000 - 10 * index - 1)
    return xs, ys


def baseline(index, method, ax, runs, args):
    assert len(runs) == 1 and runs[0].xs is None
    y = np.mean(runs[0].ys)
    kw = dict(ls="--", color=runs[0].color, zorder=5000 - 10 * index - 1)
    ax.axhline(y, label=method, **kw)


def binning(xs, ys, bins, reducer):
    binned_xs = np.arange(xs.min(), xs.max() + 1e-10, bins)
    binned_ys = []
    for start, stop in zip([-np.inf] + list(binned_xs), binned_xs):
        left = (xs <= start).sum()
        right = (xs <= stop).sum()
        binned_ys.append(reducer(ys[left:right]))
    binned_ys = np.array(binned_ys)
    return binned_xs, binned_ys


def legend(fig, mapping=None, **kwargs):
    entries = {}
    for ax in fig.axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            if mapping and label in mapping:
                label = mapping[label]
            entries[label] = handle
    leg = fig.legend(entries.values(), entries.keys(), **kwargs)
    leg.get_frame().set_edgecolor("white")
    extent = leg.get_window_extent(fig.canvas.get_renderer())
    extent = extent.transformed(fig.transFigure.inverted())
    yloc, xloc = kwargs["loc"].split()
    y0 = dict(lower=extent.y1, center=0, upper=0)[yloc]
    y1 = dict(lower=1, center=1, upper=extent.y0)[yloc]
    x0 = dict(left=extent.x1, center=0, right=0)[xloc]
    x1 = dict(left=1, center=1, right=extent.x0)[xloc]
    fig.tight_layout(rect=[x0, y0, x1, y1], h_pad=0.5, w_pad=0.5)


def save(fig, args):
    args.outdir.mkdir(parents=True, exist_ok=True)
    filename = args.outdir / "curves.png"
    fig.savefig(filename, dpi=130)
    print("Saved to", filename)
    filename = args.outdir / "curves.pdf"
    fig.savefig(filename)
    try:
        subprocess.call(["pdfcrop", str(filename), str(filename)])
    except FileNotFoundError:
        pass  # Install texlive-extra-utils.


def main(args):
    find_keys(args)
    runs = load_runs(args) + load_baselines(args)
    stats(runs)
    if not runs:
        print("Nothing to plot.")
        return
    print("Plotting...")
    fig = figure(runs, args)
    save(fig, args)


def parse_args():
    def boolean(x):
        return bool(["False", "True"].index(x))

    parser = argparse.ArgumentParser()
    parser.add_argument("--indir", nargs="+", type=pathlib.Path, required=True)
    parser.add_argument("--outdir", type=pathlib.Path, required=True)
    parser.add_argument("--subdir", type=boolean, default=True)
    parser.add_argument("--xaxis", type=str, required=True)
    parser.add_argument("--yaxis", type=str, required=True)
    parser.add_argument("--tasks", nargs="+", default=[r".*"])
    parser.add_argument("--methods", nargs="+", default=[r".*"])
    parser.add_argument("--baselines", nargs="+", default=[])
    parser.add_argument("--bins", type=float, default=0)
    parser.add_argument("--aggregate", type=str, default="std")
    parser.add_argument("--size", nargs=2, type=float, default=[2.5, 2.3])
    parser.add_argument("--cols", type=int, default=4)
    parser.add_argument("--xlim", nargs=2, type=float, default=None)
    parser.add_argument("--ylim", nargs=2, type=float, default=None)
    parser.add_argument("--xlabel", type=str, default=None)
    parser.add_argument("--ylabel", type=str, default=None)
    parser.add_argument("--xticks", type=int, default=6)
    parser.add_argument("--yticks", type=int, default=5)
    parser.add_argument("--labels", nargs="+", default=None)
    parser.add_argument("--palette", nargs="+", default=PALETTE)
    parser.add_argument("--colors", nargs="+", default={})
    args = parser.parse_args()
    if args.subdir:
        args.outdir /= args.indir[0].stem
    args.indir = [d.expanduser() for d in args.indir]
    args.outdir = args.outdir.expanduser()
    if args.labels:
        assert len(args.labels) % 2 == 0
        args.labels = {k: v for k, v in zip(args.labels[::2], args.labels[1::2])}
    if args.colors:
        assert len(args.colors) % 2 == 0
        args.colors = {k: v for k, v in zip(args.colors[::2], args.colors[1::2])}
    args.tasks = [re.compile(p) for p in args.tasks]
    args.methods = [re.compile(p) for p in args.methods]
    args.baselines = [re.compile(p) for p in args.baselines]
    args.palette = 10 * args.palette
    return args


if __name__ == "__main__":
    main(parse_args())
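The name of this deleted plotting script is not captured in the diff; assuming it was saved as `plotting.py` (like its upstream Dreamer counterpart), and that logs live under `indir/{task}/{method}/{seed}/*.jsonl` as `load_runs` expects, a typical invocation would be:

```sh
python plotting.py \
    --indir ~/tia_logs --outdir ~/plots \
    --xaxis step --yaxis train/return \
    --methods tia dreamer --bins 3e4
```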
@@ -1,11 +0,0 @@
#!/bin/bash

for seed in 0 1 2
do
    python run.py \
        --method tia \
        --configs dmc \
        --task dmc_cartpole_swingup_driving \
        --seed $seed
done
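The loop above runs the three seeds one after another. If the machine can fit several runs at once, a concurrent variant is a small change (whether three training runs fit on one GPU is not addressed here; this is only a sketch):

```sh
#!/bin/bash

# Launch all three seeds in the background and wait for them to finish.
for seed in 0 1 2; do
    python run.py --method tia --configs dmc \
        --task dmc_cartpole_swingup_driving --seed $seed &
done
wait
```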
README.md (136 lines deleted)
@@ -1,136 +0,0 @@
# Learning Task Informed Abstractions (TIA)

<sub>Left to right: Raw Observation, Dreamer, Joint of TIA, Task Stream of TIA, Distractor Stream of TIA</sub>

![](imgs/gt.gif) ![](imgs/pred.gif) ![](imgs/joint.gif) ![](imgs/main.gif) ![](imgs/disen.gif)

This codebase contains a minimal modification over [Dreamer](https://danijar.com/project/dreamer/)/[DreamerV2](https://danijar.com/project/dreamerv2/) to learn disentangled world models, presented in:

**Learning Task Informed Abstractions**

Xiang Fu*, Ge Yang*, Pulkit Agrawal, Tommi Jaakkola

ICML 2021 [[website]](https://xiangfu.co/tia) [[paper]](https://arxiv.org/abs/2106.15612)

The directory [Dreamer](./Dreamer) contains code for running DMC experiments. The directory [DreamerV2](./DreamerV2) contains code for running Atari experiments. This implementation is tested with Python 3.6, TensorFlow 2.3.1, and CUDA 10.1. The training/evaluation metrics used for producing the figures in the paper can be downloaded from [this Google Drive link](https://drive.google.com/file/d/1wvSp9Q7r2Ah5xRE_x3nJy-uwLkjF2RgX/view?usp=sharing).

## Getting started

Install the dependencies:

```sh
pip3 install --user tensorflow-gpu==2.3.1
pip3 install --user tensorflow_probability==0.11.0
pip3 install --user gym
pip3 install --user pandas
pip3 install --user matplotlib
pip3 install --user ruamel.yaml
pip3 install --user scikit-image
pip3 install --user git+git://github.com/deepmind/dm_control.git
pip3 install --user 'gym[atari]'
```

You will need an active MuJoCo license to run the DMC experiments.

## Running DMC experiments with distracting background

Code for running DMC experiments is under the directory [Dreamer](./Dreamer).

To run DMC experiments with distracting video backgrounds, you can download a small set of 16 videos (the videos whose names start with "A" in the Kinetics 400 dataset's `driving_car` class) from [this Google Drive link](https://drive.google.com/file/d/1f-ER2XnhpvQeGjlJaoGRiLR0oEjn6Le_/view?usp=sharing), which is used for producing Figure 9(a) in the paper's appendix.

To replicate the setup of [DBC](https://github.com/facebookresearch/deep_bisim4control) and use more background videos, first download the Kinetics 400 dataset and grab the `driving_car` label from the train split. You can use the repo [https://github.com/Showmax/kinetics-downloader](https://github.com/Showmax/kinetics-downloader) to download the dataset.

Train the agent:

```sh
python run.py --method dreamer --configs dmc --task dmc_cheetah_run_driving --logdir ~/logdir --video_dir VIDPATH
```

`VIDPATH` should contain `*.mp4` video files. (If you used the above repo to download the Kinetics videos, set `VIDPATH` to `PATH_TO_REPO/kinetics-downloader/dataset/train/driving_car`.)

Choose a method from:

```
[dreamer, tia, inverse]
```

corresponding to the original Dreamer, TIA, and a representation learned with an inverse model, as described in Section 4.2 of the paper.

Choose an environment + distraction combination (e.g. `dmc_cheetah_run_driving`) of the form:

```
dmc_{domain}_{task}_{distraction}
```

where {domain} (e.g., cheetah, walker, hopper) and {task} (e.g., run, walk, stand) are from the DeepMind Control Suite, and {distraction} can be chosen from:

```
[none, noise, driving]
```

where each option uses a different background:

```
none: default (no) background
noise: white noise background
driving: natural videos from the "driving car" class as background
```
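For example, combining the options above, training TIA on cheetah run against a white-noise background would look like the following (no `--video_dir` should be needed since no videos are used — an assumption based on the option descriptions, not a command from the original README):

```sh
python run.py --method tia --configs dmc --task dmc_cheetah_run_noise --logdir ~/logdir
```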

## Running Atari experiments

Code for running Atari experiments is under the directory [DreamerV2](./DreamerV2).

Train the agent on the game Demon Attack:

```sh
python dreamer.py --logdir ~/logdir/atari_demon_attack/TIA/1 \
    --configs defaults atari --task atari_demon_attack
```

## Monitoring results

Both DMC and Atari experiments log to TensorBoard by default. The decomposition into the two streams of TIA is visualized as `.gif` animations. Launch TensorBoard with:

```sh
tensorboard --logdir LOGDIR
```

## Citation

If you find this code useful, please consider citing our paper:

```
@InProceedings{fu2021learning,
  title     = {Learning Task Informed Abstractions},
  author    = {Fu, Xiang and Yang, Ge and Agrawal, Pulkit and Jaakkola, Tommi},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning},
  pages     = {3480--3491},
  year      = {2021},
  editor    = {Meila, Marina and Zhang, Tong},
  volume    = {139},
  series    = {Proceedings of Machine Learning Research},
  month     = {18--24 Jul},
  publisher = {PMLR},
  pdf       = {http://proceedings.mlr.press/v139/fu21b/fu21b.pdf},
  url       = {http://proceedings.mlr.press/v139/fu21b.html}
}
```

## Reference

We modified [Dreamer](https://github.com/danijar/dreamer) for the DMC environments and [DreamerV2](https://github.com/danijar/dreamerv2) for the Atari games. Thanks to Danijar for releasing his very clean implementations! Utilities such as

- logging with TensorBoard/JSON line files
- debugging with the `debug` flag
- mixed precision training

are the same as in the respective original implementations.
BIN  imgs/disen.gif  (binary file not shown; before: 872 KiB)
BIN  imgs/gt.gif     (binary file not shown; before: 931 KiB)
BIN  imgs/joint.gif  (binary file not shown; before: 925 KiB)
BIN  imgs/main.gif   (binary file not shown; before: 899 KiB)
BIN  imgs/pred.gif   (binary file not shown; before: 970 KiB)