97 lines
2.8 KiB
Python
97 lines
2.8 KiB
Python
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
# All rights reserved.
|
||
|
|
||
|
# This source code is licensed under the license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import dmc2gym
|
||
|
import numpy as np
|
||
|
from torch.nn import functional as F
|
||
|
|
||
|
from encoder import make_encoder
|
||
|
from decoder import make_decoder
|
||
|
from sac_ae import weight_init
|
||
|
from train import parse_args
|
||
|
import utils
|
||
|
|
||
|
|
||
|
args = parse_args()
|
||
|
args.domain_name = 'walker'
|
||
|
args.task_name = 'walk'
|
||
|
args.image_size = 84
|
||
|
args.seed = 1
|
||
|
args.agent = 'bisim'
|
||
|
args.encoder_type = 'pixel'
|
||
|
args.action_repeat = 2
|
||
|
args.img_source = 'video'
|
||
|
args.num_layers = 4
|
||
|
args.num_filters = 32
|
||
|
args.hidden_dim = 1024
|
||
|
args.resource_files = '/datasets01/kinetics/070618/400/train/driving_car/*.mp4'
|
||
|
args.total_frames = 5000
|
||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||
|
|
||
|
|
||
|
class VAE(nn.Module):
|
||
|
def __init__(self, obs_shape):
|
||
|
super().__init__()
|
||
|
self.encoder = make_encoder(
|
||
|
encoder_type='pixel',
|
||
|
obs_shape=obs_shape,
|
||
|
feature_dim=100,
|
||
|
num_layers=4,
|
||
|
num_filters=32).to(device)
|
||
|
|
||
|
self.decoder = make_decoder(
|
||
|
'pixel', obs_shape, 50, 4, 32).to(device)
|
||
|
self.decoder.apply(weight_init)
|
||
|
|
||
|
def train(self, obs):
|
||
|
h = self.encoder(obs)
|
||
|
mu, log_var = h[:, :50], h[:, 50:]
|
||
|
eps = torch.randn_like(mu)
|
||
|
reparam = mu + torch.exp(log_var / 2) * eps
|
||
|
rec_obs = torch.sigmoid(self.decoder(reparam))
|
||
|
BCE = F.binary_cross_entropy(rec_obs, obs / 255, reduction='sum')
|
||
|
KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
|
||
|
loss = BCE + KLD
|
||
|
return loss
|
||
|
|
||
|
|
||
|
env = dmc2gym.make(
|
||
|
domain_name=args.domain_name,
|
||
|
task_name=args.task_name,
|
||
|
resource_files=args.resource_files,
|
||
|
img_source=args.img_source,
|
||
|
total_frames=10,
|
||
|
seed=args.seed,
|
||
|
visualize_reward=False,
|
||
|
from_pixels=(args.encoder_type == 'pixel'),
|
||
|
height=args.image_size,
|
||
|
width=args.image_size,
|
||
|
frame_skip=args.action_repeat
|
||
|
)
|
||
|
env = utils.FrameStack(env, k=args.frame_stack)
|
||
|
vae = VAE(env.observation_space.shape)
|
||
|
train_dataset = torch.load('train_dataset.pt')
|
||
|
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
|
||
|
train_loader = torch.utils.data.DataLoader(train_dataset['obs'], batch_size=32, shuffle=True)
|
||
|
|
||
|
# training loop
|
||
|
for i in range(100):
|
||
|
total_loss = []
|
||
|
for obs_batch in train_loader:
|
||
|
optimizer.zero_grad()
|
||
|
loss = vae.train(obs_batch.to(device).float())
|
||
|
loss.backward()
|
||
|
optimizer.step()
|
||
|
total_loss.append(loss.item())
|
||
|
|
||
|
print(np.mean(total_loss), i)
|
||
|
|
||
|
dataset = torch.load('dataset.pt')
|
||
|
with torch.no_grad():
|
||
|
embeddings = vae.encoder(torch.FloatTensor(dataset['obs']).to(device)).cpu().numpy()
|
||
|
torch.save(embeddings, 'vae_embeddings.pt')
|