Compare commits

..

No commits in common. "3b614696817a11ba8cbb637a61bf6c84129a029c" and "0ac3131dad7214c33db645b70be15b3876f83bb1" have entirely different histories.

10 changed files with 0 additions and 702 deletions

5
.gitignore vendored
View File

@ -1,5 +0,0 @@
__pycache__/
*.py[cod]
*.egg-info
/dist/
MUJOCO_LOG.TXT

View File

@ -1,211 +0,0 @@
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
def binning(xs, ys, bins, reducer):
    """Aggregate ``ys`` into fixed-width buckets along ``xs``.

    Bucket edges are ``np.arange(xs.min(), xs.max() + 1e-10, bins)``. The
    first bucket collects every sample with ``x <= edges[0]``; each later
    bucket collects the samples with ``edges[i - 1] < x <= edges[i]``.
    Buckets are cut out of ``ys`` by position, so ``xs`` has to be in
    ascending order for the slices to line up.

    Args:
        xs: 1-D NumPy array of x positions, ascending.
        ys: 1-D NumPy array of values aligned with ``xs``.
        bins: Bucket width in x units (must be non-zero).
        reducer: Callable applied to each bucket's slice of ``ys``.

    Returns:
        ``(edges, reduced)``: the right edge of every bucket and an array
        holding the reducer's output for each bucket.
    """
    edges = np.arange(xs.min(), xs.max() + 1e-10, bins)
    lower_bounds = [-np.inf, *edges]
    reduced = []
    for lo, hi in zip(lower_bounds, edges):
        # Counting samples at or below an edge yields the slice boundary.
        first = (xs <= lo).sum()
        last = (xs <= hi).sum()
        reduced.append(reducer(ys[first:last]))
    return edges, np.array(reduced)
def plot_data(parent_dir, tag_filter="test/return", xaxis='step', value="AverageEpRet",
              condition="Condition1", smooth=1, bins=30000, xticks=5, yticks=5):
    """Plot the across-run mean and standard deviation of one logged metric.

    Each ``*.jsonl`` file located exactly one directory below ``parent_dir``
    is treated as one run. Its ``tag_filter`` values are averaged into buckets
    of width ``bins`` along ``xaxis``; all runs are then drawn as one seaborn
    line with a standard-deviation band, and ``plt.show()`` is called.

    Args:
        parent_dir: Directory whose immediate sub-directories hold the logs.
        tag_filter: Name of the logged metric (a column of the JSON lines).
        xaxis: Name of the logged column used as the x axis.
        value: Unused; kept so existing call sites keep working.
        condition: Unused; kept so existing call sites keep working.
        smooth: Unused; kept so existing call sites keep working.
        bins: Bucket width, in ``xaxis`` units.
        xticks: Maximum number of major ticks on the x axis.
        yticks: Maximum number of major ticks on the y axis.

    Returns:
        None. When no log provides the requested columns a message is
        printed and nothing is drawn.
    """
    frames = []
    for subfolder in os.listdir(parent_dir):
        json_dir = os.path.join(parent_dir, subfolder)
        if not os.path.isdir(json_dir):
            continue
        for json_file in os.listdir(json_dir):
            if not json_file.endswith('.jsonl'):
                continue
            df = pd.read_json(os.path.join(json_dir, json_file), lines=True)
            # Skip logs that never recorded the requested columns.
            if tag_filter not in df.columns or xaxis not in df.columns:
                continue
            df = df[[xaxis, tag_filter]].dropna().sort_values(by=xaxis)
            if df.empty:
                # binning() needs at least one sample to derive its edges.
                continue
            xs, ys = binning(df[xaxis].to_numpy(), df[tag_filter].to_numpy(), bins, np.nanmean)
            frames.append(pd.DataFrame({xaxis: xs, tag_filter: ys}))

    if not frames:
        print(f"No '{tag_filter}' data found under {parent_dir}; nothing to plot.")
        return

    combined_df = pd.concat(frames, ignore_index=True)

    sns.set(style="white", font_scale=1.5)
    ax = sns.lineplot(data=combined_df, x=xaxis, y=tag_filter, errorbar='sd')
    ax.ticklabel_format(axis="x", scilimits=(5, 5))
    steps = [1, 2, 2.5, 5, 10]
    ax.xaxis.set_major_locator(ticker.MaxNLocator(xticks, steps=steps))
    ax.yaxis.set_major_locator(ticker.MaxNLocator(yticks, steps=steps))
    # Span every run, not just the file that happened to be read last.
    x_values = combined_df[xaxis]
    ax.set_xlim(x_values.min(), x_values.max())
    plt.tight_layout(pad=0.5)
    plt.show()
# Script entry point: plot the evaluation return for the hard-coded log directory.
plot_data('/media/vedant/cpsDataStorageWK/Vedant/tia_logs/dmc_cheetah_run_driving/tia/')
# Terminate here; nothing below this line runs when the file is executed.
# `raise SystemExit` is used instead of exit(), which is a helper injected by
# the `site` module for interactive sessions and may be absent (e.g. python -S).
raise SystemExit
def plot_vanilla(parent_dir, tag_filter="train/return", smoothing=0.99):
    """Plot the across-run mean and std of an exponentially smoothed metric.

    Each ``*.jsonl`` file located exactly one directory below ``parent_dir``
    is treated as one run. The metric is smoothed per run with an exponential
    moving average, the smoothed values are grouped by ``step`` across runs,
    and the per-step mean is drawn with a +/- one standard deviation band.

    Args:
        parent_dir: Directory whose immediate sub-directories hold the logs.
        tag_filter: Name of the logged metric (a column of the JSON lines).
        smoothing: Passed unchanged as ``alpha`` to ``Series.ewm``. A larger
            alpha weights recent samples more, i.e. it smooths *less*.

    Returns:
        None. When no log provides the metric a message is printed and
        nothing is drawn.
    """
    emas = []
    for subfolder in os.listdir(parent_dir):
        json_dir = os.path.join(parent_dir, subfolder)
        if not os.path.isdir(json_dir):
            continue
        for json_file in os.listdir(json_dir):
            if not json_file.endswith('.jsonl'):
                continue
            df = pd.read_json(os.path.join(json_dir, json_file), lines=True)
            if tag_filter not in df.columns:
                continue
            # Drop rows that did not log this metric (as plot_data does) so
            # gaps do not distort the moving average.
            df = df[['step', tag_filter]].dropna().sort_values(by='step')
            if df.empty:
                continue
            df['EMA'] = df[tag_filter].ewm(alpha=smoothing, adjust=False).mean()
            emas.append(df)

    if not emas:
        print(f"No '{tag_filter}' data found under {parent_dir}; nothing to plot.")
        return

    # Aggregate the smoothed curves of all runs step by step.
    all_emas = pd.concat(emas).groupby('step')['EMA']
    mean_emas = all_emas.mean()
    std_emas = all_emas.std()

    sns.set_style("whitegrid", {'axes.grid' : True, 'axes.edgecolor':'black'})
    plt.figure()
    # The label gives plt.legend() below an entry to show.
    plt.plot(mean_emas.index, mean_emas, color='blue', label=tag_filter)
    plt.fill_between(std_emas.index, (mean_emas-std_emas), (mean_emas+std_emas), color='blue', alpha=.1)
    plt.xlabel('Training Episodes $(\\times10^6)$', fontsize=22)
    plt.ylabel('Average return', fontsize=22)
    plt.legend(frameon=True, fancybox=True, prop={'weight':'bold', 'size':14}, loc="best")
    ax = plt.gca()
    plt.setp(ax.get_xticklabels(), fontsize=16)
    plt.setp(ax.get_yticklabels(), fontsize=16)
    sns.despine()
    plt.tight_layout()
    plt.show()
# Call the function
# NOTE: the script terminates right after the plot_data(...) call further up,
# so this call is never reached unless that early exit is removed.
plot_vanilla('/media/vedant/cpsDataStorageWK/Vedant/tia_logs/dmc_cheetah_run_driving/tia/')
# The triple-quoted string below is an inert expression statement: it holds an
# earlier, standalone version of the EMA plot (alpha = 0.001) and is never run.
"""
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Set the path to the JSON file
parent_dir = '/media/vedant/cpsDataStorageWK/Vedant/tia_logs/dmc_cheetah_run_driving/tia/'
# Specific tag to filter
tag_filter = "train/return"
# Collect data from all JSON files
data = []
# Smoothing values
smoothing = 0.001 # Change num to set the number of smoothing values
# List to store all EMAs
emas = []
# Traversing through each subfolder in the parent directory
for subfolder in os.listdir(parent_dir):
json_dir = os.path.join(parent_dir, subfolder)
if not os.path.isdir(json_dir):
continue
# Read each JSON file separately
for json_file in os.listdir(json_dir):
if not json_file.endswith('.jsonl'):
continue
# Read the data from the JSON file
df = pd.read_json(os.path.join(json_dir, json_file), lines=True)
# Check if tag_filter exists in DataFrame
if tag_filter not in df.columns:
continue
df = df[['step', tag_filter]].sort_values(by='step')
# Calculate exponential moving average for the smoothing value
df['EMA'] = df[tag_filter].ewm(alpha=smoothing, adjust=False).mean()
# Append the EMA DataFrame to the emas list
emas.append(df)
# Concatenate all EMAs into a single DataFrame and calculate mean and standard deviation
all_emas = pd.concat(emas).groupby('step')['EMA']
mean_emas = all_emas.mean()
std_emas = all_emas.std()
# Plot mean and standard deviation of EMAs
plt.figure(figsize=(10, 6))
plt.plot(mean_emas.index, mean_emas)
plt.fill_between(std_emas.index, (mean_emas-std_emas), (mean_emas+std_emas), color='b', alpha=.1)
plt.legend()
plt.show()
"""

View File

@ -1,339 +0,0 @@
import argparse
import collections
import functools
import json
import multiprocessing as mp
import pathlib
import re
import subprocess
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
# import matplotlib
# matplotlib.rcParams['mathtext.fontset'] = 'stix'
# matplotlib.rcParams['font.family'] = 'STIXGeneral'
# One plotted entry. Learning curves carry array-valued ``xs``/``ys``; baseline
# scores are stored with ``seed`` and ``xs`` set to None.
Run = collections.namedtuple(
    "Run", ["task", "method", "seed", "xs", "ys", "color"]
)

# Default colour cycle: 13 hex colours repeated ten times so that indexing by
# the number of methods seen so far stays in range for up to 130 methods.
PALETTE = (
    "#377eb8",
    "#4daf4a",
    "#984ea3",
    "#e41a1c",
    "#ff7f00",
    "#a65628",
    "#f781bf",
    "#888888",
    "#a6cee3",
    "#b2df8a",
    "#cab2d6",
    "#fb9a99",
    "#fdbf6f",
) * 10

# Keyword arguments forwarded to ``fig.legend`` for the shared figure legend.
LEGEND = {
    "fontsize": "medium",
    "numpoints": 1,
    "labelspacing": 0,
    "columnspacing": 1.2,
    "handlelength": 1.5,
    "handletextpad": 0.5,
    "ncol": 4,
    "loc": "lower center",
}
def find_keys(args):
    """Print the metric names logged by the first run under ``args.indir[0]``.

    Only the first ``*.jsonl`` file yielded by a recursive glob of the first
    input directory is inspected; the union of the keys of all its records is
    printed in sorted order.

    Args:
        args: Namespace whose ``indir`` is a non-empty list of ``pathlib.Path``.

    Returns:
        None. Prints a notice instead of failing when no log file exists.
    """
    filename = next(args.indir[0].glob("**/*.jsonl"), None)
    if filename is None:
        print(f"No .jsonl files found under {args.indir[0]}.", flush=True)
        return
    keys = set()
    with filename.open() as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank or whitespace-only lines
                keys |= json.loads(line).keys()
    # Sorted so the output does not depend on set iteration order.
    print(f"Keys ({len(keys)}):", ", ".join(sorted(keys)), flush=True)
def load_runs(args):
    """Find, filter and load every run log under the input directories.

    Logs are expected at ``<indir>/<task>/<method>/<seed>/<name>.jsonl``.
    Files at any other depth are reported and skipped. A run is kept when its
    task matches one of ``args.tasks`` and its method matches one of
    ``args.methods``. Methods without a colour yet receive the next entry of
    ``args.palette`` (``args.colors`` is updated in place).

    Args:
        args: Namespace providing ``indir`` (list of paths), ``tasks`` and
            ``methods`` (lists of compiled patterns), ``colors`` (dict) and
            ``palette`` (sequence of colours), plus whatever ``load_run`` reads.

    Returns:
        List of the ``Run`` objects for which ``load_run`` did not return None.
    """
    toload = []
    total = 0
    for indir in args.indir:
        # Sorted so that colour assignment does not depend on filesystem order.
        filenames = sorted(indir.glob("**/*.jsonl"))
        total += len(filenames)
        for filename in filenames:
            relative = filename.relative_to(indir)
            parts = relative.parts[:-1]
            if len(parts) != 3:
                print("Skipping", relative, "(expected task/method/seed/<file>.jsonl)")
                continue
            task, method, _seed = parts
            if not any(p.search(task) for p in args.tasks):
                continue
            if not any(p.search(method) for p in args.methods):
                continue
            if method not in args.colors:
                args.colors[method] = args.palette[len(args.colors)]
            toload.append((filename, indir))
    # Count files from every input directory, not only the last one.
    print(f"Loading {len(toload)} of {total} runs...")
    if not toload:
        return []
    jobs = [functools.partial(load_run, f, i, args) for f, i in toload]
    with mp.Pool(min(10, len(jobs))) as pool:
        promises = [pool.apply_async(j) for j in jobs]
        runs = [p.get() for p in promises]
    return [r for r in runs if r is not None]
def load_run(filename, indir, args, num_steps=1000000):
    """Load one JSON-lines log into a ``Run``.

    Records are read in file order until one reports a ``"step"`` greater
    than ``num_steps`` (that record is still kept). Blank lines and records
    without a ``"step"`` field are tolerated.

    Args:
        filename: Path of the log, located at ``indir/task/method/seed/``.
        indir: Input directory that ``filename`` is relative to.
        args: Namespace providing ``xaxis`` and ``yaxis`` (column names) and
            ``colors`` (method name to colour).
        num_steps: Step count after which reading stops.

    Returns:
        A ``Run`` holding the rows where both columns are present, or None
        when the file is not valid JSON lines or lacks one of the columns.
    """
    task, method, seed = filename.relative_to(indir).parts[:-1]
    records = []
    try:
        # Parsed by hand instead of pd.read_json(..., lines=True) because the
        # logs may contain NaN values.
        with filename.open() as f:
            for line in f:
                if not line.strip():
                    continue
                record = json.loads(line)
                records.append(record)
                step = record.get("step")
                if step is not None and step > num_steps:
                    break
        df = pd.DataFrame(records)
    except ValueError as e:
        print("Invalid", filename.relative_to(indir), e)
        return None
    try:
        df = df[[args.xaxis, args.yaxis]].dropna()
    except KeyError:
        # This run never logged the requested metric.
        return None
    xs = df[args.xaxis].to_numpy()
    ys = df[args.yaxis].to_numpy()
    color = args.colors[method]
    return Run(task, method, seed, xs, ys, color)
def load_baselines(args):
    """Collect reference scores from the ``baselines`` folder beside this file.

    Every ``*.json`` file found recursively is read as a mapping
    ``{task: {method: score}}``. A score is kept when its method name matches
    one of the ``args.baselines`` patterns. Methods without a colour yet get
    the next entry of ``args.palette`` (``args.colors`` is updated in place).

    Returns:
        List of ``Run`` tuples whose ``seed`` and ``xs`` are None and whose
        ``ys`` is the stored score.
    """
    baseline_dir = pathlib.Path(__file__).parent / "baselines"
    collected = []
    for path in baseline_dir.glob("**/*.json"):
        scores_by_task = json.loads(path.read_text())
        for task, methods in scores_by_task.items():
            for method, score in methods.items():
                wanted = any(p.search(method) for p in args.baselines)
                if wanted:
                    if method not in args.colors:
                        args.colors[method] = args.palette[len(args.colors)]
                    entry = Run(task, method, None, None, score, args.colors[method])
                    collected.append(entry)
    return collected
def stats(runs):
    """Print a short inventory of what was loaded.

    Entries whose ``xs`` is None count as baselines; everything else counts
    as a run. The run count is printed first, followed by the sorted distinct
    tasks, methods and seeds of the runs and the baseline method names.
    """
    curves = [r for r in runs if r.xs is not None]
    baseline_names = sorted({r.method for r in runs if r.xs is None})
    summary = [
        ("Tasks", sorted({r.task for r in curves})),
        ("Methods", sorted({r.method for r in curves})),
        ("Seeds", sorted({r.seed for r in curves})),
        ("Baselines", baseline_names),
    ]
    print("Loaded", len(curves), "runs.")
    for title, names in summary:
        print(f"{title} ({len(names)}):", ", ".join(names))
def figure(runs, args):
    """Build the grid of per-task subplots and the shared legend.

    One subplot is created per task that has at least one curve (an entry
    with ``xs`` set). Tasks are laid out row by row in ``args.cols`` columns
    and unused cells are switched off.

    Args:
        runs: Sequence of ``Run`` entries (curves and baselines).
        args: Namespace providing ``cols``, ``size`` (width, height of one
            cell), ``xlim``, ``xlabel``, ``ylabel`` and ``labels``, plus
            whatever ``plot`` reads.

    Returns:
        The matplotlib figure.
    """
    tasks = sorted({r.task for r in runs if r.xs is not None})
    # Keep at least one row so plt.subplots() is never asked for an empty grid.
    rows = max(1, int(np.ceil(len(tasks) / args.cols)))
    figsize = args.size[0] * args.cols, args.size[1] * rows
    # squeeze=False keeps `axes` two-dimensional even for one row or one cell,
    # which the row/column indexing below depends on.
    fig, axes = plt.subplots(rows, args.cols, figsize=figsize, squeeze=False)
    flat_axes = axes.flatten()
    for task, ax in zip(tasks, flat_axes):
        relevant = [r for r in runs if r.task == task]
        plot(task, ax, relevant, args)
    if args.xlim:
        # With shared limits the offset text is only needed on the last row.
        for ax in axes[:-1].flatten():
            ax.xaxis.get_offset_text().set_visible(False)
    if args.xlabel:
        for ax in axes[-1]:
            ax.set_xlabel(args.xlabel)
    if args.ylabel:
        for ax in axes[:, 0]:
            ax.set_ylabel(args.ylabel)
    # Hide the cells left over after the last task.
    for ax in flat_axes[len(tasks):]:
        ax.axis("off")
    legend(fig, args.labels, **LEGEND)
    return fig
def plot(task, ax, runs, args):
    """Draw every method's curve or baseline for one task onto ``ax``.

    Args:
        task: Task name; ``"env_name"`` is titled ``"Env Name"``, a name
            without an underscore is title-cased as a whole.
        ax: Matplotlib axes to draw on.
        runs: ``Run`` entries belonging to this task.
        args: Namespace providing ``aggregate`` (``"std"`` or ``"none"``),
            ``xticks``, ``yticks``, ``xlim`` and ``ylim``, plus whatever the
            curve helpers read.

    Raises:
        NotImplementedError: If ``args.aggregate`` is not a supported mode.
    """
    # A failed two-way split raises ValueError on unpacking, so test for the
    # separator explicitly instead of relying on an exception.
    env, separator, rest = task.partition("_")
    if separator:
        title = env.capitalize() + " " + rest.capitalize()
    else:
        title = task.title()
    ax.set_title(title)

    # Curves first, then baselines; the position decides the z-order.
    methods = sorted({r.method for r in runs if r.xs is not None})
    methods += sorted({r.method for r in runs if r.xs is None})

    xlim = [+np.inf, -np.inf]
    for index, method in enumerate(methods):
        relevant = [r for r in runs if r.method == method]
        if not relevant:
            continue
        if any(r.xs is None for r in relevant):
            baseline(index, method, ax, relevant, args)
            continue
        if args.aggregate == "std":
            xs, ys = curve_std(index, method, ax, relevant, args)
        elif args.aggregate == "none":
            xs, ys = curve_individual(index, method, ax, relevant, args)
        else:
            raise NotImplementedError(args.aggregate)
        xlim = [min(xlim[0], xs.min()), max(xlim[1], xs.max())]

    ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    steps = [1, 2, 2.5, 5, 10]
    ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps))
    ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps))

    if args.xlim:
        ax.set_xlim(args.xlim)
        # Make sure both limits show up as ticks.
        ticks = sorted({*ax.get_xticks(), *args.xlim})
        ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]]
        ax.set_xticks(ticks)
    elif np.all(np.isfinite(xlim)):
        # Only apply data-derived limits when at least one curve was drawn.
        ax.set_xlim(xlim)
    if args.ylim:
        ax.set_ylim(args.ylim)
        ticks = sorted({*ax.get_yticks(), *args.ylim})
        ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]]
        ax.set_yticks(ticks)
def curve_individual(index, method, ax, runs, args):
    """Draw each run of ``method`` as its own line.

    Args:
        index: Position of the method within the subplot; sets the z-order.
        method: Method name, used as the legend label of every line.
        ax: Matplotlib axes to draw on.
        runs: List of ``Run`` entries for this method. When ``args.bins`` is
            set, its elements are replaced in place by their binned versions.
        args: Namespace providing ``bins`` (bucket width, 0 disables binning).

    Returns:
        ``(xs, ys)`` of the first run.
    """
    if args.bins:
        # Use a separate loop variable: reusing `index` here would overwrite
        # the method position that the z-order below is derived from.
        for position, run in enumerate(runs):
            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
            runs[position] = run._replace(xs=xs, ys=ys)
    zorder = 10000 - 10 * index - 1
    for run in runs:
        ax.plot(run.xs, run.ys, label=method, color=run.color, zorder=zorder)
    # NOTE(review): only the first run's extent is reported to the caller;
    # confirm whether longer runs should widen the x-limits as well.
    return runs[0].xs, runs[0].ys
def curve_std(index, method, ax, runs, args):
    """Draw the mean curve of ``method`` with a standard-deviation band.

    With ``args.bins`` set, every run is binned on its own first, the binned
    runs are pooled and sorted by x, and the pool is binned again to obtain
    the mean and standard deviation per bucket. Without binning the pooled,
    sorted samples are drawn as a single line and no band is drawn.

    Args:
        index: Position of the method within the subplot; sets the z-order.
        method: Method name, used as the legend label.
        ax: Matplotlib axes to draw on.
        runs: List of ``Run`` entries for this method. When ``args.bins`` is
            set, its elements are replaced in place by their binned versions.
        args: Namespace providing ``bins`` (bucket width, 0 disables binning).

    Returns:
        ``(xs, ys)`` of the line that was drawn.
    """
    if args.bins:
        # Use a separate loop variable: reusing `index` here would overwrite
        # the method position that the z-orders below are derived from.
        for position, run in enumerate(runs):
            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
            runs[position] = run._replace(xs=xs, ys=ys)
    xs = np.concatenate([r.xs for r in runs])
    ys = np.concatenate([r.ys for r in runs])
    order = np.argsort(xs)
    xs, ys = xs[order], ys[order]
    color = runs[0].color
    if args.bins:
        def reducer(y):
            return (np.nanmean(np.array(y)), np.nanstd(np.array(y)))

        xs, ys = binning(xs, ys, args.bins, reducer)
        ys, std = ys.T
        kw = dict(color=color, zorder=10000 - 10 * index, alpha=0.1, linewidths=0)
        ax.fill_between(xs, ys - std, ys + std, **kw)
    ax.plot(xs, ys, label=method, color=color, zorder=10000 - 10 * index - 1)
    return xs, ys
def baseline(index, method, ax, runs, args):
    """Mark a baseline score as a dashed horizontal line on ``ax``.

    Expects exactly one entry in ``runs`` and that entry must have no x
    values. The line is placed at the mean of the entry's ``ys`` and sits
    below the learning curves in z-order.
    """
    assert len(runs) == 1 and runs[0].xs is None
    entry = runs[0]
    level = np.mean(entry.ys)
    ax.axhline(
        level,
        label=method,
        ls="--",
        color=entry.color,
        zorder=5000 - 10 * index - 1,
    )
def binning(xs, ys, bins, reducer):
    """Reduce ``ys`` over consecutive x-intervals of width ``bins``.

    The right edges of the intervals are
    ``np.arange(xs.min(), xs.max() + 1e-10, bins)``; the first interval is
    open to the left (it starts at minus infinity). Samples are assigned by
    position, which assumes ``xs`` is sorted in ascending order.

    Args:
        xs: 1-D NumPy array of x positions, ascending.
        ys: 1-D NumPy array of values aligned with ``xs``.
        bins: Interval width in x units (must be non-zero).
        reducer: Callable applied to the ``ys`` slice of each interval.

    Returns:
        Tuple of the right edges and an array of the reduced values.
    """
    right_edges = np.arange(xs.min(), xs.max() + 1e-10, bins)
    left_edges = [-np.inf] + list(right_edges)

    def count_upto(edge):
        # Number of samples at or below `edge`, i.e. the slice boundary.
        return (xs <= edge).sum()

    reduced = [
        reducer(ys[count_upto(lo):count_upto(hi)])
        for lo, hi in zip(left_edges, right_edges)
    ]
    return right_edges, np.array(reduced)
def legend(fig, mapping=None, **kwargs):
    """Add one de-duplicated legend to ``fig`` and lay the axes out around it.

    Handles are gathered from every axes of the figure and keyed by label, so
    a label that appears in several subplots is listed once. Labels found in
    ``mapping`` are replaced by their mapped display name. ``kwargs`` are
    forwarded to ``fig.legend`` and must contain a two-word ``loc`` such as
    ``"lower center"``; the legend's extent on that side is reserved when
    calling ``fig.tight_layout``.
    """
    handles_by_label = {}
    for ax in fig.axes:
        handles, labels = ax.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            if mapping and label in mapping:
                label = mapping[label]
            handles_by_label[label] = handle
    leg = fig.legend(handles_by_label.values(), handles_by_label.keys(), **kwargs)
    leg.get_frame().set_edgecolor("white")

    # Legend bounding box expressed in figure-relative coordinates.
    renderer = fig.canvas.get_renderer()
    extent = leg.get_window_extent(renderer).transformed(fig.transFigure.inverted())

    # Shrink the layout rectangle on the side occupied by the legend.
    yloc, xloc = kwargs["loc"].split()
    bottom = {"lower": extent.y1, "center": 0, "upper": 0}[yloc]
    top = {"lower": 1, "center": 1, "upper": extent.y0}[yloc]
    left = {"left": extent.x1, "center": 0, "right": 0}[xloc]
    right = {"left": 1, "center": 1, "right": extent.x0}[xloc]
    fig.tight_layout(rect=[left, bottom, right, top], h_pad=0.5, w_pad=0.5)
def save(fig, args):
    """Write the figure to ``args.outdir`` as ``curves.png`` and ``curves.pdf``.

    The output directory is created if needed. The PDF is cropped in place
    with the external ``pdfcrop`` tool when that tool is available.
    """
    outdir = args.outdir
    outdir.mkdir(parents=True, exist_ok=True)

    png_path = outdir / "curves.png"
    fig.savefig(png_path, dpi=130)
    print("Saved to", png_path)

    pdf_path = outdir / "curves.pdf"
    fig.savefig(pdf_path)
    crop_command = ["pdfcrop", str(pdf_path), str(pdf_path)]
    try:
        subprocess.call(crop_command)
    except FileNotFoundError:
        # pdfcrop is not installed (see texlive-extra-utils); keep the
        # uncropped PDF.
        pass
def main(args):
    """Run the plotting pipeline: inspect, load, summarise, draw and save.

    Args:
        args: Parsed command-line namespace as returned by ``parse_args``.
    """
    find_keys(args)
    runs = load_runs(args) + load_baselines(args)
    stats(runs)
    if not runs:
        print("Nothing to plot.")
        return
    print("Plotting...")
    fig = figure(runs, args)
    save(fig, args)
def parse_args(argv=None):
    """Parse and normalise the command-line options of the plotting script.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace in which paths are user-expanded, ``tasks``/``methods``/
        ``baselines`` are compiled regular expressions, ``labels`` and
        ``colors`` are dicts built from ``KEY VALUE`` pairs, and ``palette``
        is repeated ten times. With ``--subdir True`` (the default) the stem
        of the first input directory is appended to ``outdir``.
    """
    def boolean(x):
        # Accept exactly "True"/"False"; anything else raises ValueError,
        # which argparse reports as an invalid value.
        return bool(["False", "True"].index(x))

    def pairs(flat, option):
        # Turn [k1, v1, k2, v2, ...] into {k1: v1, k2: v2, ...}. Keys are the
        # even positions only, so values are never mistaken for keys.
        if len(flat) % 2 != 0:
            parser.error(f"--{option} expects KEY VALUE pairs")
        return dict(zip(flat[::2], flat[1::2]))

    parser = argparse.ArgumentParser()
    parser.add_argument("--indir", nargs="+", type=pathlib.Path, required=True)
    parser.add_argument("--outdir", type=pathlib.Path, required=True)
    parser.add_argument("--subdir", type=boolean, default=True)
    parser.add_argument("--xaxis", type=str, required=True)
    parser.add_argument("--yaxis", type=str, required=True)
    parser.add_argument("--tasks", nargs="+", default=[r".*"])
    parser.add_argument("--methods", nargs="+", default=[r".*"])
    parser.add_argument("--baselines", nargs="+", default=[])
    parser.add_argument("--bins", type=float, default=0)
    parser.add_argument("--aggregate", type=str, default="std")
    parser.add_argument("--size", nargs=2, type=float, default=[2.5, 2.3])
    parser.add_argument("--cols", type=int, default=4)
    parser.add_argument("--xlim", nargs=2, type=float, default=None)
    parser.add_argument("--ylim", nargs=2, type=float, default=None)
    parser.add_argument("--xlabel", type=str, default=None)
    parser.add_argument("--ylabel", type=str, default=None)
    parser.add_argument("--xticks", type=int, default=6)
    parser.add_argument("--yticks", type=int, default=5)
    parser.add_argument("--labels", nargs="+", default=None)
    parser.add_argument("--palette", nargs="+", default=PALETTE)
    parser.add_argument("--colors", nargs="+", default={})
    args = parser.parse_args(argv)
    if args.subdir:
        args.outdir /= args.indir[0].stem
    args.indir = [d.expanduser() for d in args.indir]
    args.outdir = args.outdir.expanduser()
    if args.labels:
        args.labels = pairs(args.labels, "labels")
    if args.colors:
        args.colors = pairs(args.colors, "colors")
    args.tasks = [re.compile(p) for p in args.tasks]
    args.methods = [re.compile(p) for p in args.methods]
    args.baselines = [re.compile(p) for p in args.baselines]
    args.palette = 10 * args.palette
    return args


# Run the command-line tool only when this file is executed as a script.
if __name__ == "__main__":
    main(parse_args())

View File

@ -1,11 +0,0 @@
#!/bin/bash
# Launch TIA training on dmc_cartpole_swingup_driving once per seed (0, 1, 2).
# The runs execute one after another, not in parallel.
for seed in 0 1 2; do
    python run.py --method tia --configs dmc --task dmc_cartpole_swingup_driving --seed "$seed"
done

136
README.md
View File

@ -1,136 +0,0 @@
# Learning Task Informed Abstractions (TIA)
<sub>Left to right: Raw Observation, Dreamer, Joint of TIA, Task Stream of TIA, Distractor Stream of TIA</sub>
![](imgs/gt.gif) ![](imgs/pred.gif) ![](imgs/joint.gif) ![](imgs/main.gif) ![](imgs/disen.gif)
This code base contains a minimal modification over [Dreamer](https://danijar.com/project/dreamer/)/[DreamerV2](https://danijar.com/project/dreamerv2/) to learn disentangled world models, presented in:
**Learning Task Informed Abstractions**
Xiang Fu*, Ge Yang*, Pulkit Agrawal, Tommi Jaakkola
ICML 2021 [[website]](https://xiangfu.co/tia) [[paper]](https://arxiv.org/abs/2106.15612)
The directory [Dreamer](./Dreamer) contains code for running DMC experiments. The directory [DreamerV2](./DreamerV2) contains code for running Atari experiments. This implementation is tested with Python 3.6, Tensorflow 2.3.1 and CUDA 10.1. The training/evaluation metrics used for producing the figures in the paper can be downloaded from [this Google Drive link](https://drive.google.com/file/d/1wvSp9Q7r2Ah5xRE_x3nJy-uwLkjF2RgX/view?usp=sharing).
## Getting started
Get dependencies:
```sh
pip3 install --user tensorflow-gpu==2.3.1
pip3 install --user tensorflow_probability==0.11.0
pip3 install --user gym
pip3 install --user pandas
pip3 install --user matplotlib
pip3 install --user ruamel.yaml
pip3 install --user scikit-image
pip3 install --user git+https://github.com/deepmind/dm_control.git
pip3 install --user 'gym[atari]'
```
You will need an active Mujoco license for running DMC experiments.
## Running DMC experiments with distracting background
Code for running DMC experiments is under the directory [Dreamer](./Dreamer).
To run DMC experiments with distracting video backgrounds, you can download a small set of 16 videos (those whose names start with "A" in the `driving_car` class of the Kinetics 400 dataset) from [this Google Drive link](https://drive.google.com/file/d/1f-ER2XnhpvQeGjlJaoGRiLR0oEjn6Le_/view?usp=sharing). This set was used to produce Figure 9(a) in the paper's appendix.
To replicate the setup of [DBC](https://github.com/facebookresearch/deep_bisim4control) and use more background videos, first download the Kinetics 400 dataset and grab the `driving_car` label from the train dataset. Use the repo:
[https://github.com/Showmax/kinetics-downloader](https://github.com/Showmax/kinetics-downloader)
to download the dataset.
Train the agent:
```sh
python run.py --method dreamer --configs dmc --task dmc_cheetah_run_driving --logdir ~/logdir --video_dir VIDPATH
```
`VIDPATH` should contain `*.mp4` video files. (If you used the above repo to download the Kinetics videos, set `VIDPATH` to `PATH_TO_REPO/kinetics-downloader/dataset/train/driving_car`.)
Choose method from:
```
[dreamer, tia, inverse]
```
corresponding to the original Dreamer, TIA, and representation learned with an inverse model as described in Section 4.2 of the paper.
Choose environment + distraction (e.g. `dmc_cheetah_run_driving`):
```
dmc_{domain}_{task}_{distraction}
```
where {domain} (e.g., cheetah, walker, hopper, etc.) and {task} (e.g., run, walk, stand, etc.) are from the DeepMind Control Suite, and distraction can be chosen from:
```
[none, noise, driving]
```
where each option uses different backgrounds:
```
none: default (no) background
noise: white noise background
driving: natural videos from the "driving car" class as background
```
## Running Atari experiments
Code for running Atari experiments is under the directory [DreamerV2](./DreamerV2).
Train the agent with the game Demon Attack:
```sh
python dreamer.py --logdir ~/logdir/atari_demon_attack/TIA/1 \
--configs defaults atari --task atari_demon_attack
```
## Monitoring results
Both DMC and Atari experiments log with tensorboard by default. The decomposition of the two streams of TIA is visualized in `.gif` animation. Access tensorboard with the command:
```sh
tensorboard --logdir LOGDIR
```
## Citation
If you find this code useful, please consider citing our paper:
```
@InProceedings{fu2021learning,
title = {Learning Task Informed Abstractions},
author = {Fu, Xiang and Yang, Ge and Agrawal, Pulkit and Jaakkola, Tommi},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
pages = {3480--3491},
year = {2021},
editor = {Meila, Marina and Zhang, Tong},
volume = {139},
series = {Proceedings of Machine Learning Research},
month = {18--24 Jul},
publisher = {PMLR},
pdf = {http://proceedings.mlr.press/v139/fu21b/fu21b.pdf},
url = {http://proceedings.mlr.press/v139/fu21b.html}
}
```
## Reference
We modify [Dreamer](https://github.com/danijar/dreamer) for DMC environments and [DreamerV2](https://github.com/danijar/dreamerv2) for Atari games. Thanks Danijar for releasing his very clean implementation! Utilities such as
- Logging with Tensorboard/JSON line files
- debugging with the `debug` flag
- mixed precision training
are the same as in the respective original implementations.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 872 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 931 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 925 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 899 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 970 KiB