DBC/train_vae.py
2020-10-12 15:39:25 -07:00

97 lines
2.8 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import dmc2gym
import numpy as np
from torch.nn import functional as F
from encoder import make_encoder
from decoder import make_decoder
from sac_ae import weight_init
from train import parse_args
import utils
# Hyper-parameters for VAE training on the walker-walk DMC task, applied
# on top of the CLI defaults produced by train.parse_args().
args = parse_args()
_OVERRIDES = {
    'domain_name': 'walker',
    'task_name': 'walk',
    'image_size': 84,
    'seed': 1,
    'agent': 'bisim',
    'encoder_type': 'pixel',
    'action_repeat': 2,
    'img_source': 'video',
    'num_layers': 4,
    'num_filters': 32,
    'hidden_dim': 1024,
    'resource_files': '/datasets01/kinetics/070618/400/train/driving_car/*.mp4',
    'total_frames': 5000,
}
for _key, _value in _OVERRIDES.items():
    setattr(args, _key, _value)

# Prefer GPU when available; every module/tensor below is moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class VAE(nn.Module):
    """Convolutional VAE over image observations.

    The encoder emits a ``2 * z_dim`` feature vector that is split into the
    posterior mean and log-variance; the decoder reconstructs the observation
    from a reparameterized latent sample.

    NOTE(review): ``train`` shadows ``nn.Module.train(mode)``.  Calling
    ``vae.eval()`` would internally invoke ``self.train(False)`` and crash,
    so avoid the standard train/eval mode switches on this class.
    """

    def __init__(self, obs_shape, z_dim=50):
        """Build encoder/decoder for observations of shape ``obs_shape``.

        z_dim defaults to 50, matching the original hard-coded sizes
        (encoder feature_dim was 100 == 2 * 50, decoder latent was 50).
        """
        super().__init__()
        # Encoder outputs mu and log_var concatenated, hence 2 * z_dim.
        self.encoder = make_encoder(
            encoder_type='pixel',
            obs_shape=obs_shape,
            feature_dim=2 * z_dim,
            num_layers=4,
            num_filters=32).to(device)
        self.decoder = make_decoder(
            'pixel', obs_shape, z_dim, 4, 32).to(device)
        self.decoder.apply(weight_init)
        self.z_dim = z_dim

    def train(self, obs):
        """Return the VAE loss (reconstruction BCE + KL) for a batch.

        obs is assumed to be a float tensor of raw pixels in [0, 255]
        (it is rescaled by /255 before being used as the BCE target) —
        TODO confirm against the saved dataset.
        """
        h = self.encoder(obs)
        mu, log_var = h[:, :self.z_dim], h[:, self.z_dim:]
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        eps = torch.randn_like(mu)
        reparam = mu + torch.exp(log_var / 2) * eps
        rec_obs = torch.sigmoid(self.decoder(reparam))
        # Sum-reduced BCE against pixel intensities rescaled to [0, 1].
        BCE = F.binary_cross_entropy(rec_obs, obs / 255, reduction='sum')
        # Analytic KL between N(mu, sigma^2) and the standard normal prior.
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return BCE + KLD
# Build the distracting-background DMC environment (pixel observations).
env = dmc2gym.make(
    domain_name=args.domain_name,
    task_name=args.task_name,
    resource_files=args.resource_files,
    img_source=args.img_source,
    # Use the configured frame budget: the previous hard-coded value of 10
    # silently ignored the args.total_frames = 5000 set above.
    total_frames=args.total_frames,
    seed=args.seed,
    visualize_reward=False,
    from_pixels=(args.encoder_type == 'pixel'),
    height=args.image_size,
    width=args.image_size,
    frame_skip=args.action_repeat
)
# NOTE(review): args.frame_stack is never overridden above, so this relies
# on the parse_args() default — verify it matches the saved dataset.
env = utils.FrameStack(env, k=args.frame_stack)
vae = VAE(env.observation_space.shape)
# Pre-collected observations; expected to expose an 'obs' key — TODO confirm.
train_dataset = torch.load('train_dataset.pt')
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
train_loader = torch.utils.data.DataLoader(
    train_dataset['obs'], batch_size=32, shuffle=True)
# Training loop: minimize reconstruction + KL loss over the offline dataset.
for epoch in range(100):
    epoch_losses = []
    for obs_batch in train_loader:
        optimizer.zero_grad()
        # vae.train(...) here is the custom loss method defined on VAE,
        # not nn.Module.train() — it returns the scalar loss for the batch.
        loss = vae.train(obs_batch.to(device).float())
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
    print(np.mean(epoch_losses), epoch)

# Encode the evaluation dataset and save the raw encoder features.
# NOTE(review): the encoder output concatenates [mu, log_var]; downstream
# consumers may want only the first half (mu) — confirm before use.
dataset = torch.load('dataset.pt')
obs_tensor = torch.FloatTensor(dataset['obs'])
with torch.no_grad():
    # Encode in fixed-size chunks so the whole dataset never has to fit on
    # the GPU at once (the original pushed every observation in one pass,
    # producing the same values but risking OOM on large datasets).
    chunks = [
        vae.encoder(obs_tensor[i:i + 256].to(device)).cpu()
        for i in range(0, len(obs_tensor), 256)
    ]
embeddings = torch.cat(chunks).numpy()
torch.save(embeddings, 'vae_embeddings.pt')