import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
def binning(xs, ys, bins, reducer):
    """Aggregate ``ys`` into fixed-width bins along ``xs``.

    Bin edges are ``np.arange(xs.min(), xs.max() + 1e-10, bins)``. For every
    edge, the samples whose x lies in ``(previous_edge, edge]`` are passed to
    ``reducer``; the first bin has no lower bound. Bins are cut out of ``ys``
    by position, so ``xs`` must be sorted in ascending order.

    Args:
        xs: 1-D NumPy array of x coordinates, sorted ascending.
        ys: 1-D NumPy array of values aligned with ``xs``.
        bins: Positive bin width, in the units of ``xs``.
        reducer: Callable applied to the slice of ``ys`` falling in one bin
            (for example ``np.nanmean``).

    Returns:
        Tuple ``(binned_xs, binned_ys)`` with the bin edges and an array
        holding one reducer result per edge.
    """
    binned_xs = np.arange(xs.min(), xs.max() + 1e-10, bins)
    lower_edges = [-np.inf] + list(binned_xs)
    binned_ys = []
    for start, stop in zip(lower_edges, binned_xs):
        # Counting samples <= edge gives the slice boundaries in sorted xs.
        left = (xs <= start).sum()
        right = (xs <= stop).sum()
        binned_ys.append(reducer(ys[left:right]))
    return binned_xs, np.array(binned_ys)
def plot_data(
    parent_dir,
    tag_filter="test/return",
    xaxis='step',
    value="AverageEpRet",
    condition="Condition1",
    smooth=1,
    bins=30000,
    xticks=5,
    yticks=5,
):
    """Plot the binned mean and standard deviation of one metric across runs.

    Every ``*.jsonl`` file found one directory level below ``parent_dir`` is
    read, reduced to its ``step`` and ``tag_filter`` columns, binned along
    ``step`` with ``np.nanmean`` and added to a combined frame that is drawn
    with ``sns.lineplot(..., errorbar='sd')``.

    Args:
        parent_dir: Directory whose sub-directories contain ``.jsonl`` logs.
        tag_filter: Name of the metric column to plot.
        xaxis: Column of the combined frame used for the x axis. The binned
            frames only contain ``'step'`` and ``tag_filter``.
        value: Accepted for compatibility; not used.
        condition: Accepted for compatibility; not used.
        smooth: Accepted for compatibility; not used.
        bins: Bin width along ``step`` handed to ``binning``.
        xticks: Maximum number of major tick intervals on the x axis.
        yticks: Maximum number of major tick intervals on the y axis.

    Raises:
        ValueError: If no log file provides ``tag_filter``.
    """
    frames = []
    for subfolder in os.listdir(parent_dir):
        json_dir = os.path.join(parent_dir, subfolder)
        if not os.path.isdir(json_dir):
            continue
        for json_file in os.listdir(json_dir):
            if not json_file.endswith('.jsonl'):
                continue
            df = pd.read_json(os.path.join(json_dir, json_file), lines=True)
            if tag_filter not in df.columns:
                continue
            df = df[['step', tag_filter]].dropna().sort_values(by='step')
            if df.empty:
                # Nothing left to bin; binning would fail on xs.min().
                continue
            xs, ys = binning(df['step'].to_numpy(), df[tag_filter].to_numpy(), bins, np.nanmean)
            frames.append(pd.DataFrame({'step': xs, tag_filter: ys}))

    if not frames:
        raise ValueError(f"No .jsonl file under {parent_dir!r} contains {tag_filter!r}")
    combined_df = pd.concat(frames, ignore_index=True)

    sns.set(style="white", font_scale=1.5)
    plot = sns.lineplot(data=combined_df, x=xaxis, y=tag_filter, errorbar='sd')
    ax = plot.axes
    ax.ticklabel_format(axis="x", scilimits=(5, 5))
    steps = [1, 2, 2.5, 5, 10]
    ax.xaxis.set_major_locator(ticker.MaxNLocator(xticks, steps=steps))
    ax.yaxis.set_major_locator(ticker.MaxNLocator(yticks, steps=steps))
# Call the function
def plot_vanilla(parent_dir, tag_filter="train/return", smoothing=0.99):
    """Plot the mean and standard deviation of a metric's EMA across runs.

    Every ``*.jsonl`` file found one directory level below ``parent_dir`` is
    read and sorted by ``step``; an exponential moving average of
    ``tag_filter`` is computed per file, and the EMAs are grouped by ``step``
    to obtain their mean and standard deviation.

    Args:
        parent_dir: Directory whose sub-directories contain ``.jsonl`` logs.
        tag_filter: Name of the metric column to smooth and plot.
        smoothing: Passed unchanged as ``alpha`` to ``Series.ewm``: it is the
            weight given to the newest sample, so values near 1 track the raw
            data and values near 0 smooth heavily.

    Raises:
        ValueError: If no log file provides ``tag_filter``.
    """
    emas = []
    for subfolder in os.listdir(parent_dir):
        json_dir = os.path.join(parent_dir, subfolder)
        if not os.path.isdir(json_dir):
            continue
        for json_file in os.listdir(json_dir):
            if not json_file.endswith('.jsonl'):
                continue
            df = pd.read_json(os.path.join(json_dir, json_file), lines=True)
            if tag_filter not in df.columns:
                continue
            df = df[['step', tag_filter]].sort_values(by='step')
            df['EMA'] = df[tag_filter].ewm(alpha=smoothing, adjust=False).mean()
            emas.append(df)

    if not emas:
        raise ValueError(f"No .jsonl file under {parent_dir!r} contains {tag_filter!r}")
    all_emas = pd.concat(emas).groupby('step')['EMA']
    mean_emas = all_emas.mean()
    std_emas = all_emas.std()

    sns.set_style("whitegrid", {'axes.grid' : True, 'axes.edgecolor':'black'})
    plt.figure()
    # The label gives plt.legend() below an entry to show.
    plt.plot(mean_emas.index, mean_emas, color='blue', label=tag_filter)
    plt.fill_between(std_emas.index, (mean_emas-std_emas), (mean_emas+std_emas), color='blue', alpha=.1)
    plt.xlabel('Training Episodes $(\\times10^6)$', fontsize=22)
    plt.ylabel('Average return', fontsize=22)
    plt.legend(frameon=True, fancybox=True, prop={'weight':'bold', 'size':14}, loc="best")
    ax = plt.gca()
    plt.setp(ax.get_xticklabels(), fontsize=16)
    plt.setp(ax.get_yticklabels(), fontsize=16)
# Call the function
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Directory whose sub-directories hold the .jsonl log files
parent_dir = '/media/vedant/cpsDataStorageWK/Vedant/tia_logs/dmc_cheetah_run_driving/tia/'
# Specific tag to filter
tag_filter = "train/return"
# Collect data from all JSON files
data = []
# EMA weight of the newest sample (passed as alpha to ewm below)
smoothing = 0.001
# List to store all EMAs
emas = []
# Traversing through each subfolder in the parent directory
for subfolder in os.listdir(parent_dir):
    json_dir = os.path.join(parent_dir, subfolder)
    if not os.path.isdir(json_dir):
        continue
    # Read each JSON file separately
    for json_file in os.listdir(json_dir):
        if not json_file.endswith('.jsonl'):
            continue
        # Read the data from the JSON file
        df = pd.read_json(os.path.join(json_dir, json_file), lines=True)
        # Skip files that do not log the requested tag
        if tag_filter not in df.columns:
            continue
        df = df[['step', tag_filter]].sort_values(by='step')
        # Calculate exponential moving average for the smoothing value
        df['EMA'] = df[tag_filter].ewm(alpha=smoothing, adjust=False).mean()
        # Append the EMA DataFrame to the emas list
        emas.append(df)
# Concatenate all EMAs into a single DataFrame and calculate mean and standard deviation
all_emas = pd.concat(emas).groupby('step')['EMA']
mean_emas = all_emas.mean()
std_emas = all_emas.std()
# Plot mean and standard deviation of EMAs
plt.figure(figsize=(10, 6))
plt.plot(mean_emas.index, mean_emas)
plt.fill_between(std_emas.index, (mean_emas-std_emas), (mean_emas+std_emas), color='b', alpha=.1)
import argparse
import collections
import functools
import json
import multiprocessing as mp
import pathlib
import re
import subprocess
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
# import matplotlib
# matplotlib.rcParams['mathtext.fontset'] = 'stix'
# matplotlib.rcParams[''] = 'STIXGeneral'
# Run: one plotted series. Curves carry xs/ys arrays; load_baselines() stores
# seed=None and xs=None with the baseline score in ys.
# PALETTE: default for --palette; indexed to assign a color to each new method.
# LEGEND: keyword arguments forwarded to legend(), which splits "loc" into two
# words (vertical, horizontal).
# NOTE(review): both the PALETTE tuple and the LEGEND dict are left unclosed in
# this copy of the file; their remaining entries must be restored.
Run = collections.namedtuple("Run", "task method seed xs ys color")
PALETTE = 10 * (
LEGEND = dict(
loc="lower center",
def find_keys(args):
    """Print every JSON key used in the first ``*.jsonl`` file under ``args.indir[0]``.

    The file is located with a recursive glob, each non-empty line is parsed
    as a JSON object, and the union of all keys is printed on one line.
    """
    first_log = next(args.indir[0].glob("**/*.jsonl"))
    keys = set()
    for raw in first_log.read_text().split("\n"):
        if not raw:
            continue
        keys.update(json.loads(raw).keys())
    print(f"Keys ({len(keys)}):", ", ".join(keys), flush=True)
def load_runs(args):
    """Discover the log files selected by ``args`` and load them in parallel.

    Each ``*.jsonl`` file under an input directory must sit exactly three
    directory levels deep (``task/method/seed/<file>``). A file is kept when
    its task matches one of the compiled patterns in ``args.tasks`` and its
    method matches one in ``args.methods``. Methods without a color yet get
    the next entry of ``args.palette``, recorded in ``args.colors``.

    Args:
        args: Parsed options providing ``indir``, ``tasks``, ``methods``,
            ``colors`` and ``palette`` (plus whatever ``load_run`` reads).

    Returns:
        List of the ``Run`` records that ``load_run`` could load.
    """
    total = 0
    toload = []
    for indir in args.indir:
        filenames = list(indir.glob("**/*.jsonl"))
        # Count across all input directories, not only the last one.
        total += len(filenames)
        for filename in filenames:
            task, method, seed = filename.relative_to(indir).parts[:-1]
            if not any(p.search(task) for p in args.tasks):
                continue
            if not any(p.search(method) for p in args.methods):
                continue
            if method not in args.colors:
                args.colors[method] = args.palette[len(args.colors)]
            toload.append((filename, indir))
    print(f"Loading {len(toload)} of {total} runs...")
    jobs = [functools.partial(load_run, f, i, args) for f, i in toload]
    with mp.Pool(10) as pool:
        promises = [pool.apply_async(j) for j in jobs]
        runs = [p.get() for p in promises]
    # load_run reports unreadable files as None.
    runs = [r for r in runs if r is not None]
    return runs
def load_run(filename, indir, args):
    """Load one JSON-lines log file into a ``Run`` record.

    Records whose ``"step"`` exceeds 1,000,000 are ignored. The remaining
    records are reduced to the ``args.xaxis`` and ``args.yaxis`` columns and
    rows with missing values are dropped.

    Args:
        filename: Path of the ``.jsonl`` file, located at
            ``indir/task/method/seed/<file>``.
        indir: Input directory that ``filename`` is relative to.
        args: Parsed options providing ``xaxis``, ``yaxis`` and ``colors``.

    Returns:
        A ``Run`` for the file, or ``None`` when the file is not valid JSON
        lines or lacks one of the requested columns.
    """
    task, method, seed = filename.relative_to(indir).parts[:-1]
    num_steps = 1000000
    try:
        # Future pandas releases will support JSON files with NaN values.
        # df = pd.read_json(filename, lines=True)
        with filename.open() as f:
            data_arr = []
            for line in f:
                data = json.loads(line)
                if data["step"] > num_steps:
                    continue
                data_arr.append(data)
        df = pd.DataFrame(data_arr)
    except ValueError as e:
        print("Invalid", filename.relative_to(indir), e)
        return None
    try:
        df = df[[args.xaxis, args.yaxis]].dropna()
    except KeyError:
        # The file never logged one of the requested columns.
        return None
    xs = df[args.xaxis].to_numpy()
    ys = df[args.yaxis].to_numpy()
    color = args.colors[method]
    return Run(task, method, seed, xs, ys, color)
def load_baselines(args):
    """Load reference scores from the ``baselines`` directory next to this file.

    Every ``*.json`` file found there is expected to map
    ``task -> {method: score}``. A method is kept when it matches one of the
    compiled patterns in ``args.baselines``; methods without a color yet get
    the next entry of ``args.palette``, recorded in ``args.colors``.

    Args:
        args: Parsed options providing ``baselines``, ``colors`` and
            ``palette``.

    Returns:
        List of ``Run`` records with ``seed`` and ``xs`` set to ``None`` and
        the score stored in ``ys``.
    """
    runs = []
    directory = pathlib.Path(__file__).parent / "baselines"
    for filename in directory.glob("**/*.json"):
        for task, methods in json.loads(filename.read_text()).items():
            for method, score in methods.items():
                if not any(p.search(method) for p in args.baselines):
                    continue
                if method not in args.colors:
                    args.colors[method] = args.palette[len(args.colors)]
                color = args.colors[method]
                runs.append(Run(task, method, None, None, score, color))
    return runs
def stats(runs):
    """Print how many curve runs were loaded and which names they cover.

    Runs whose ``xs`` is ``None`` count as baselines: only their method names
    are listed. All other runs contribute their task, method and seed to the
    sorted, de-duplicated listings.
    """
    curve_runs = []
    baseline_names = set()
    for run in runs:
        if run.xs is None:
            baseline_names.add(run.method)
        else:
            curve_runs.append(run)

    tasks = sorted({run.task for run in curve_runs})
    methods = sorted({run.method for run in curve_runs})
    seeds = sorted({run.seed for run in curve_runs})
    baselines = sorted(baseline_names)

    print("Loaded", len(curve_runs), "runs.")
    print(f"Tasks ({len(tasks)}):", ", ".join(tasks))
    print(f"Methods ({len(methods)}):", ", ".join(methods))
    print(f"Seeds ({len(seeds)}):", ", ".join(seeds))
    print(f"Baselines ({len(baselines)}):", ", ".join(baselines))
def figure(runs, args):
    """Build the grid of per-task subplots and the shared legend.

    One subplot is drawn per task that has at least one curve run; the grid
    has ``args.cols`` columns and as many rows as needed. Cells beyond the
    last task are switched off.

    Args:
        runs: ``Run`` records (curves and baselines) for all tasks.
        args: Parsed options providing ``cols``, ``size``, ``xlim``,
            ``xlabel``, ``ylabel`` and ``labels`` (plus whatever ``plot``
            reads).

    Returns:
        The created matplotlib figure.
    """
    tasks = sorted(set(r.task for r in runs if r.xs is not None))
    # Keep at least one row so subplots() accepts a baseline-only input.
    rows = max(1, int(np.ceil(len(tasks) / args.cols)))
    figsize = args.size[0] * args.cols, args.size[1] * rows
    # squeeze=False keeps `axes` two-dimensional even for one row or column,
    # which the row/column indexing below relies on.
    fig, axes = plt.subplots(rows, args.cols, figsize=figsize, squeeze=False)
    for task, ax in zip(tasks, axes.flatten()):
        relevant = [r for r in runs if r.task == task]
        plot(task, ax, relevant, args)
    if args.xlim:
        # With shared limits, only the bottom row keeps the axis offset text.
        for ax in axes[:-1].flatten():
            ax.xaxis.get_offset_text().set_visible(False)
    if args.xlabel:
        for ax in axes[-1]:
            ax.set_xlabel(args.xlabel)
    if args.ylabel:
        for ax in axes[:, 0]:
            ax.set_ylabel(args.ylabel)
    for ax in axes.flatten()[len(tasks):]:
        ax.axis("off")
    legend(fig, args.labels, **LEGEND)
    return fig
def plot(task, ax, runs, args):
    """Draw every method of one task into ``ax``.

    Curve methods are drawn with ``curve_std`` or ``curve_individual``
    depending on ``args.aggregate``; methods containing a run without ``xs``
    are drawn as horizontal baselines.

    Args:
        task: Task name; ``"env_name"`` is titled ``"Env Name"``.
        ax: Matplotlib axes to draw into.
        runs: ``Run`` records belonging to this task.
        args: Parsed options providing ``aggregate``, ``xticks``, ``yticks``,
            ``xlim`` and ``ylim`` (plus whatever the curve helpers read).

    Raises:
        NotImplementedError: If ``args.aggregate`` is neither ``"std"`` nor
            ``"none"``.
    """
    try:
        env, name = task.split("_", 1)
        title = env.capitalize() + " " + name.capitalize()
    except ValueError:
        # No underscore: unpacking the single-element split fails.
        title = task.title()
    ax.set_title(title)

    methods = []
    methods += sorted(set(r.method for r in runs if r.xs is not None))
    methods += sorted(set(r.method for r in runs if r.xs is None))
    xlim = [+np.inf, -np.inf]
    for index, method in enumerate(methods):
        relevant = [r for r in runs if r.method == method]
        if not relevant:
            continue
        if any(r.xs is None for r in relevant):
            baseline(index, method, ax, relevant, args)
            continue
        if args.aggregate == "std":
            xs, ys = curve_std(index, method, ax, relevant, args)
        elif args.aggregate == "none":
            xs, ys = curve_individual(index, method, ax, relevant, args)
        else:
            raise NotImplementedError(args.aggregate)
        xlim = [min(xlim[0], xs.min()), max(xlim[1], xs.max())]

    ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    steps = [1, 2, 2.5, 5, 10]
    ax.xaxis.set_major_locator(ticker.MaxNLocator(args.xticks, steps=steps))
    ax.yaxis.set_major_locator(ticker.MaxNLocator(args.yticks, steps=steps))
    if args.xlim or np.isfinite(xlim).all():
        # Without any curve the data range stays infinite; keep the default.
        ax.set_xlim(args.xlim or xlim)
    if args.xlim:
        # Make sure the requested limits themselves carry a tick.
        ticks = sorted({*ax.get_xticks(), *args.xlim})
        ticks = [x for x in ticks if args.xlim[0] <= x <= args.xlim[1]]
        ax.set_xticks(ticks)
    if args.ylim:
        ax.set_ylim(args.ylim)
        ticks = sorted({*ax.get_yticks(), *args.ylim})
        ticks = [x for x in ticks if args.ylim[0] <= x <= args.ylim[1]]
        ax.set_yticks(ticks)
def curve_individual(index, method, ax, runs, args):
    """Draw each run of a method as its own line.

    When ``args.bins`` is set, every run is first binned with ``np.nanmean``
    and the entry in ``runs`` is replaced by the binned run (the list is
    modified in place).

    Args:
        index: Position of the method in the plot; lower indices are drawn
            on top (higher ``zorder``).
        method: Method name, used as the line label.
        ax: Matplotlib axes to draw into.
        runs: Non-empty list of ``Run`` records of this method.
        args: Parsed options providing ``bins``.

    Returns:
        Tuple ``(xs, ys)`` of the first run.
    """
    if args.bins:
        # A separate loop variable keeps `index` intact for the zorder below.
        for position, run in enumerate(runs):
            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
            runs[position] = run._replace(xs=xs, ys=ys)
    zorder = 10000 - 10 * index - 1
    for run in runs:
        ax.plot(run.xs, run.ys, label=method, color=run.color, zorder=zorder)
    return runs[0].xs, runs[0].ys
def curve_std(index, method, ax, runs, args):
    """Draw the mean curve of a method with a one-standard-deviation band.

    All runs are pooled and sorted by x. When ``args.bins`` is set, each run
    is binned first (replacing the entries of ``runs`` in place), the pooled
    points are binned again into mean and standard deviation, and the band
    ``mean +/- std`` is filled. Without bins only the pooled points are
    drawn as a line.

    Args:
        index: Position of the method in the plot; lower indices are drawn
            on top (higher ``zorder``).
        method: Method name, used as the line label.
        ax: Matplotlib axes to draw into.
        runs: Non-empty list of ``Run`` records of this method.
        args: Parsed options providing ``bins``.

    Returns:
        Tuple ``(xs, ys)`` of the plotted line.
    """
    if args.bins:
        # A separate loop variable keeps `index` intact for the zorder below.
        for position, run in enumerate(runs):
            xs, ys = binning(run.xs, run.ys, args.bins, np.nanmean)
            runs[position] = run._replace(xs=xs, ys=ys)
    xs = np.concatenate([r.xs for r in runs])
    ys = np.concatenate([r.ys for r in runs])
    order = np.argsort(xs)
    xs, ys = xs[order], ys[order]
    color = runs[0].color
    if args.bins:
        def reducer(y):
            return (np.nanmean(np.array(y)), np.nanstd(np.array(y)))

        xs, ys = binning(xs, ys, args.bins, reducer)
        # binning returns one (mean, std) row per bin; split the columns.
        ys, std = ys.T
        kw = dict(color=color, zorder=10000 - 10 * index, alpha=0.1, linewidths=0)
        ax.fill_between(xs, ys - std, ys + std, **kw)
    ax.plot(xs, ys, label=method, color=color, zorder=10000 - 10 * index - 1)
    return xs, ys
def baseline(index, method, ax, runs, args):
    """Draw a baseline method as a dashed horizontal line.

    ``runs`` must contain exactly one record with ``xs`` set to ``None``; the
    line is placed at the mean of that record's ``ys``.
    """
    assert len(runs) == 1 and runs[0].xs is None
    (run,) = runs
    level = np.mean(run.ys)
    depth = 5000 - 10 * index - 1
    ax.axhline(level, label=method, ls="--", color=run.color, zorder=depth)
def binning(xs, ys, bins, reducer):
    """Aggregate ``ys`` into fixed-width bins along ``xs``.

    Bin edges are ``np.arange(xs.min(), xs.max() + 1e-10, bins)``. For every
    edge, the samples whose x lies in ``(previous_edge, edge]`` are passed to
    ``reducer``; the first bin has no lower bound. Bins are cut out of ``ys``
    by position, so ``xs`` must be sorted in ascending order.

    Args:
        xs: 1-D NumPy array of x coordinates, sorted ascending.
        ys: 1-D NumPy array of values aligned with ``xs``.
        bins: Positive bin width, in the units of ``xs``.
        reducer: Callable applied to the slice of ``ys`` falling in one bin
            (for example ``np.nanmean``).

    Returns:
        Tuple ``(binned_xs, binned_ys)`` with the bin edges and an array
        holding one reducer result per edge.
    """
    binned_xs = np.arange(xs.min(), xs.max() + 1e-10, bins)
    lower_edges = [-np.inf] + list(binned_xs)
    binned_ys = []
    for start, stop in zip(lower_edges, binned_xs):
        # Counting samples <= edge gives the slice boundaries in sorted xs.
        left = (xs <= start).sum()
        right = (xs <= stop).sum()
        binned_ys.append(reducer(ys[left:right]))
    return binned_xs, np.array(binned_ys)
def legend(fig, mapping=None, **kwargs):
    """Add one figure-level legend and shrink the subplot area to make room.

    Handles are collected from every axes of ``fig`` and de-duplicated by
    label (the last handle seen for a label wins); labels found in
    ``mapping`` are replaced by their mapped text. ``kwargs`` go to
    ``fig.legend`` and must include ``loc`` as two words, vertical then
    horizontal (e.g. ``"lower center"``), which decides the side of the
    layout rectangle that is pulled in by the legend's extent.
    """
    entries = {}
    for ax in fig.axes:
        handles, labels = ax.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            shown = mapping[label] if mapping and label in mapping else label
            entries[shown] = handle
    leg = fig.legend(entries.values(), entries.keys(), **kwargs)

    # Legend bounding box in figure coordinates.
    box = leg.get_window_extent(fig.canvas.get_renderer())
    box = box.transformed(fig.transFigure.inverted())

    yloc, xloc = kwargs["loc"].split()
    bottom = {"lower": box.y1, "center": 0, "upper": 0}[yloc]
    top = {"lower": 1, "center": 1, "upper": box.y0}[yloc]
    left = {"left": box.x1, "center": 0, "right": 0}[xloc]
    right = {"left": 1, "center": 1, "right": box.x0}[xloc]
    fig.tight_layout(rect=[left, bottom, right, top], h_pad=0.5, w_pad=0.5)
def save(fig, args):
    """Write the figure to ``args.outdir`` as ``curves.png`` and ``curves.pdf``.

    The output directory is created if needed. After writing the PDF, the
    external ``pdfcrop`` tool is run on it in place; if the tool is not
    installed the uncropped PDF is kept.

    Args:
        fig: Figure object providing ``savefig``.
        args: Parsed options providing ``outdir`` as a ``pathlib.Path``.
    """
    args.outdir.mkdir(parents=True, exist_ok=True)
    filename = args.outdir / "curves.png"
    fig.savefig(filename, dpi=130)
    print("Saved to", filename)
    filename = args.outdir / "curves.pdf"
    fig.savefig(filename)
    try:
        subprocess.call(["pdfcrop", str(filename), str(filename)])
    except FileNotFoundError:
        pass  # Install texlive-extra-utils.
    print("Saved to", filename)
def main(args):
    """Load curve runs and baselines, plot them, and save the figure.

    Returns early, after printing a message, when nothing was loaded.
    """
    runs = load_runs(args) + load_baselines(args)
    if not runs:
        print("Noting to plot.")
        return
    fig = figure(runs, args)
    save(fig, args)
def parse_args():
    """Parse the command line and normalise the options used by the plotter.

    After parsing: ``outdir`` gets the stem of the first input directory
    appended when ``--subdir`` is true; paths have ``~`` expanded;
    ``--labels`` and ``--colors`` (flat ``key value key value ...`` lists)
    become dicts; ``--tasks``, ``--methods`` and ``--baselines`` are compiled
    to regular expressions; and the palette is repeated ten times.

    Returns:
        The populated ``argparse.Namespace``.
    """
    def boolean(x):
        # Accept exactly "True"/"False"; anything else raises ValueError,
        # which argparse reports as an invalid value.
        return bool(["False", "True"].index(x))

    def pairs(option, values):
        # Turn [k1, v1, k2, v2, ...] into {k1: v1, k2: v2, ...}.
        if len(values) % 2 != 0:
            parser.error(f"{option} expects an even number of values (key value ...)")
        return dict(zip(values[::2], values[1::2]))

    parser = argparse.ArgumentParser()
    parser.add_argument("--indir", nargs="+", type=pathlib.Path, required=True)
    parser.add_argument("--outdir", type=pathlib.Path, required=True)
    parser.add_argument("--subdir", type=boolean, default=True)
    parser.add_argument("--xaxis", type=str, required=True)
    parser.add_argument("--yaxis", type=str, required=True)
    parser.add_argument("--tasks", nargs="+", default=[r".*"])
    parser.add_argument("--methods", nargs="+", default=[r".*"])
    parser.add_argument("--baselines", nargs="+", default=[])
    parser.add_argument("--bins", type=float, default=0)
    parser.add_argument("--aggregate", type=str, default="std")
    parser.add_argument("--size", nargs=2, type=float, default=[2.5, 2.3])
    parser.add_argument("--cols", type=int, default=4)
    parser.add_argument("--xlim", nargs=2, type=float, default=None)
    parser.add_argument("--ylim", nargs=2, type=float, default=None)
    parser.add_argument("--xlabel", type=str, default=None)
    parser.add_argument("--ylabel", type=str, default=None)
    parser.add_argument("--xticks", type=int, default=6)
    parser.add_argument("--yticks", type=int, default=5)
    parser.add_argument("--labels", nargs="+", default=None)
    parser.add_argument("--palette", nargs="+", default=PALETTE)
    parser.add_argument("--colors", nargs="+", default={})
    args = parser.parse_args()
    if args.subdir:
        args.outdir /= args.indir[0].stem
    args.indir = [d.expanduser() for d in args.indir]
    args.outdir = args.outdir.expanduser()
    if args.labels:
        args.labels = pairs("--labels", args.labels)
    if args.colors:
        args.colors = pairs("--colors", args.colors)
    args.tasks = [re.compile(p) for p in args.tasks]
    args.methods = [re.compile(p) for p in args.methods]
    args.baselines = [re.compile(p) for p in args.baselines]
    args.palette = 10 * args.palette
    return args
if __name__ == "__main__":
    # Script entry point: parse the command line and produce the plots.
    main(parse_args())
# Launch one run per seed (0, 1, 2) with method "tia", the "dmc" config and the
# dmc_cartpole_swingup_driving task.
# NOTE(review): the loop has no `do`/`done`, and the script name after `python`
# is a `||||` placeholder in this copy — restore both before running.
for seed in 0 1 2
python \
|||| \
--method tia \
--configs dmc \
--task dmc_cartpole_swingup_driving \
--seed $seed
# Learning Task Informed Abstractions (TIA)
<sub>Left to right: Raw Observation, Dreamer, Joint of TIA, Task Stream of TIA, Distractor Stream of TIA</sub>
    
This code base contains a minimal modification of Dreamer and DreamerV2 to learn disentangled world models, presented in:
**Learning Task Informed Abstractions**
Xiang Fu*, Ge Yang*, Pulkit Agrawal, Tommi Jaakkola
ICML 2021 [[website]]( [[paper]](
The directory [Dreamer](./Dreamer) contains code for running DMC experiments. The directory [DreamerV2](./DreamerV2) contains code for running Atari experiments. This implementation is tested with Python 3.6, Tensorflow 2.3.1 and CUDA 10.1. The training/evaluation metrics used for producing the figures in the paper can be downloaded from [this Google Drive link](
## Getting started
Get dependencies:
pip3 install --user tensorflow-gpu==2.3.1
pip3 install --user tensorflow_probability==0.11.0
pip3 install --user gym
pip3 install --user pandas
pip3 install --user matplotlib
pip3 install --user ruamel.yaml
pip3 install --user scikit-image
pip3 install --user git+git://
pip3 install --user 'gym[atari]'
You will need an active Mujoco license for running DMC experiments.
## Running DMC experiments with distracting background
Code for running DMC experiments is under the directory [Dreamer](./Dreamer).
To run DMC experiments with distracting video backgrounds, you can download a small set of 16 videos (videos with names starting with "A" in the Kinetics 400 dataset's `driving_car` class) from this Google Drive link, which was used to produce Figure 9(a) in the paper's appendix.
To replicate the setup of [DBC]( and use more background videos, first download the Kinetics 400 dataset and grab the `driving_car` label from the train dataset. Use the repo:
to download the dataset.
Train the agent:
python --method dreamer --configs dmc --task dmc_cheetah_run_driving --logdir ~/logdir --video_dir VIDPATH
`VIDPATH` should contain `*.mp4` video files. (If you used the above repo to download the Kinetics videos, set `VIDPATH` to `PATH_TO_REPO/kinetics-downloader/dataset/train/driving_car`.)
Choose method from:
[dreamer, tia, inverse]
corresponding to the original Dreamer, TIA, and representation learned with an inverse model as described in Section 4.2 of the paper.
Choose environment + distraction (e.g. `dmc_cheetah_run_driving`):
where {domain} (e.g., cheetah, walker, hopper, etc.) and {task} (e.g., run, walk, stand, etc.) are from the DeepMind Control Suite, and distraction can be chosen from:
[none, noise, driving]
where each option uses different backgrounds:
none: default (no) background
noise: white noise background
driving: natural videos from the ''driving car'' class as background
## Running Atari experiments
Code for running Atari experiments is under the directory [DreamerV2](./DreamerV2).
Train the agent with the game Demon Attack:
python --logdir ~/logdir/atari_demon_attack/TIA/1 \
--configs defaults atari --task atari_demon_attack
## Monitoring results
Both DMC and Atari experiments log with tensorboard by default. The decomposition of the two streams of TIA is visualized in `.gif` animation. Access tensorboard with the command:
tensorboard --logdir LOGDIR
## Citation
If you find this code useful, please consider citing our paper:
title = {Learning Task Informed Abstractions},
author = {Fu, Xiang and Yang, Ge and Agrawal, Pulkit and Jaakkola, Tommi},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
pages = {3480--3491},
year = {2021},
editor = {Meila, Marina and Zhang, Tong},
volume = {139},
series = {Proceedings of Machine Learning Research},
month = {18--24 Jul},
publisher = {PMLR},
pdf = {},
url = {}
## Reference
We modify Dreamer for DMC environments and DreamerV2 for Atari games. Thanks to Danijar for releasing his very clean implementation! Utilities such as
- Logging with Tensorboard/JSON line files
- debugging with the `debug` flag
- mixed precision training
are the same as in the respective original implementations.
