Updating
This commit is contained in:
parent
6a11c9ce61
commit
b3660fe103
429
tac_ssl.py
429
tac_ssl.py
@ -1,118 +1,131 @@
|
|||||||
import os
|
import os
|
||||||
|
import wandb
|
||||||
import pickle
|
import pickle
|
||||||
|
import argparse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from util import *
|
||||||
|
from generate_dataset import TouchFolderLabel
|
||||||
|
from linear_classifier import LinearClassifierResNet
|
||||||
|
from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torch.utils.data import random_split
|
from torch.utils.data import random_split
|
||||||
|
from torch.optim.lr_scheduler import StepLR
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
writer = SummaryWriter('runs/mmssl')
|
|
||||||
|
|
||||||
# Custom dataset
|
def parse_option(argv=None):
    """Build the command-line parser for training and return the parsed options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour;
            passing a list lets callers and tests parse without touching the
            process arguments.

    Returns:
        argparse.Namespace holding every option declared below.

    Raises:
        ValueError: if ``data_folder`` or ``model_path`` ends up as ``None``.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10, help='print frequency')
    parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
    parser.add_argument('--batch_size', type=int, default=256, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=18, help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=61, help='number of training epochs')
    parser.add_argument('--num_layers', type=int, default=5, help='number of layers in resnet')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.03, help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='120,160', help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    # resume path
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')

    # model definition
    # 'alexnet' is the default, so it must also be an accepted explicit value;
    # argparse does not validate defaults against `choices`, which is why the
    # mismatch went unnoticed.
    parser.add_argument('--model', type=str, default='alexnet', choices=[
        'alexnet',
        'resnet50t1', 'resnet101t1', 'resnet18t1',
        'resnet50t2', 'resnet101t2', 'resnet18t2',
        'resnet50t3', 'resnet101t3', 'resnet18t3'])
    parser.add_argument('--softmax', action='store_true', help='using softmax contrastive loss rather than NCE')
    parser.add_argument('--feat_dim', type=int, default=128, help='dim of feat for inner product')

    # dataset
    parser.add_argument('--dataset', type=str, default='touch_and_go', choices=['touch_and_go', 'pretrain', 'touch_rough', 'touch_hard'])

    # specify folder
    parser.add_argument('--data_folder', type=str, default="dataset/", help='path to dataset')
    parser.add_argument('--data_loader', type=str, default='touch_and_go', choices=['touch_and_go'])
    parser.add_argument('--model_path', type=str, default="ckpt/mmssl", help='path to save model')

    # add new views
    parser.add_argument('--view', type=str, default='Touch', choices=['Touch'])

    # mixed precision setting
    parser.add_argument('--amp', action='store_true', help='using mixed precision')
    parser.add_argument('--opt_level', type=str, default='O2', choices=['O1', 'O2'])

    # data crop threshold
    parser.add_argument('--crop_low', type=float, default=0.2, help='low area in crop')

    # data amount
    parser.add_argument('--data_amount', type=int, default=100, help='how much data used')
    parser.add_argument('--comment', type=str, default='', help='comment')

    # wandb
    parser.add_argument('--wandb', action='store_true', help='Enable wandb')
    parser.add_argument('--wandb_name', type=str, default=None, help='username of wandb')

    opt = parser.parse_args(argv)

    if (opt.data_folder is None) or (opt.model_path is None):
        raise ValueError('one or more of the folders is None: data_folder | model_path')

    return opt
|
||||||
# Load the indices from disk
|
|
||||||
with open('indices/train_indices.pkl', 'rb') as f:
|
|
||||||
train_indices = pickle.load(f)
|
|
||||||
|
|
||||||
with open('indices/test_indices.pkl', 'rb') as f:
|
|
||||||
test_indices = pickle.load(f)
|
|
||||||
|
|
||||||
# Create subset datasets and DataLoaders
|
|
||||||
train_subset = torch.utils.data.Subset(dataset, train_indices)
|
|
||||||
test_subset = torch.utils.data.Subset(dataset, test_indices)
|
|
||||||
|
|
||||||
train_dataloader = DataLoader(train_subset, batch_size=96, shuffle=True)
|
def logging(epoch, idx, train_loader_len, loss, acc1, acc5, losses, top1, top5, pretrain, train=True):
    """Print a progress line and push the current metrics to wandb.

    Args:
        epoch: Current epoch number.
        idx: Index of the current batch within the epoch.
        train_loader_len: Number of batches per epoch.
        loss: Scalar loss of the current batch.
        acc1, acc5: Accuracy results; only element ``[0]`` of each is logged,
            and only when ``pretrain`` is false.
        losses, top1: Meter objects exposing ``.val`` and ``.avg``; used for
            the printed line when ``pretrain`` is false.
        top5: Accepted for signature symmetry; not read here.
        pretrain: When true only the loss is reported, at a per-batch step.
        train: Selects the "training ..." or "validation ..." metric names
            (ignored when ``pretrain`` is true).
    """
    if pretrain:
        print('Epoch: [{0}][{1}/{2}]\t'
              'Loss {loss:.4f}'.format(
                  epoch, idx, train_loader_len, loss=loss))
        # Pretraining logs at a global batch step rather than per epoch.
        global_step = epoch * train_loader_len + idx
        wandb.log({"training loss": loss}, step=global_step)
        return

    progress_line = ('Epoch: [{0}][{1}/{2}]\t'
                     'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                     'Acc@1 {top1.val:.3f} ({top1.avg:.3f})')
    print(progress_line.format(epoch, idx, train_loader_len, loss=losses, top1=top1))

    first_acc = acc1[0]
    second_acc = acc5[0]
    if train:
        metrics = {"training loss": loss,
                   "training accuracy": first_acc,
                   "training top5 accuracy": second_acc}
    else:
        metrics = {"validation loss": loss,
                   "validation accuracy": first_acc,
                   "validation top5 accuracy": second_acc}
    wandb.log(metrics, step=epoch)
|
||||||
|
|
||||||
# Initialize model
|
def train(epoch, train_loader, model, optimizer, classifier=None, criterion=None, task=None, scheduler=None):
    """Run one training epoch, either contrastive pretraining or classifier training.

    When ``classifier`` is ``None`` the model itself is trained: each batch is
    unpacked as ``(inputs, _, index)`` and ``model(...)`` returns the loss.
    Otherwise the model is put in eval mode and ``classifier`` is trained with
    ``criterion`` on features from ``model.tactile_base_q``; each batch is
    unpacked as ``(inputs, target)``.

    Relies on module-level ``device``, ``data_transforms`` and (in classifier
    mode) ``args``.

    Args:
        epoch: Current epoch number (forwarded to the model and to logging).
        train_loader: Iterable of batches; ``inputs`` is indexed as a 4-D
            tensor whose first three channels are the vision image and whose
            remaining channels are the tactile image.
        model: Network exposing ``tactile_base_q`` / ``phi_vision_q`` and
            callable on four augmented views.
        optimizer: Optimizer stepped once per batch.
        classifier: Optional classification head; selects classifier mode.
        criterion: Loss for classifier mode.
        task: Unused here.
        scheduler: Unused here (no scheduler step is taken).
    """
    pretrain = classifier is None

    # Logging setup
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Mode switches and device moves do not change between batches, so do
    # them once instead of on every iteration.
    if pretrain:
        model.train()
    else:
        classifier = classifier.to(device)
        criterion = criterion.to(device)
        model.eval()
        classifier.train()

    n_batches = len(train_loader)

    # Training loop
    for idx, values in enumerate(train_loader):
        if pretrain:
            inputs, _, _ = values
        else:
            inputs, target = values

        x_vision, x_tactile = inputs[:, :3, :, :], inputs[:, 3:, :, :]

        # Augment images: two independently augmented views per modality.
        x_vision_q = data_transforms(x_vision).to(device)
        x_vision_k = data_transforms(x_vision).to(device)

        x_tactile_q = data_transforms(x_tactile).to(device)
        x_tactile_k = data_transforms(x_tactile).to(device)

        # Forward pass to get the loss
        if pretrain:
            loss = model(x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, idx, n_batches)
            # Accuracy is undefined during pretraining; logging() ignores these.
            # The meters are deliberately left intact (they used to be
            # overwritten with 0 here).
            acc1, acc5 = 0, 0
        else:
            feat = model.tactile_base_q(x_tactile_q)
            if args.num_layers == 3:
                # NOTE(review): tactile features go through the *vision*
                # projection head here - confirm this is intended.
                feat = model.phi_vision_q(feat)

            output = classifier(feat)
            # Follow the configured device instead of hard-coding CUDA so the
            # script also runs on CPU-only machines.
            target = target.to(device)
            loss = criterion(output, target)

            # NOTE(review): topk=(1, 1) makes the "top5" figure identical to
            # top-1; presumably chosen because some tasks have only 2 labels.
            acc1, acc5 = accuracy(output, target, topk=(1, 1))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0], inputs.size(0))
            top5.update(acc5[0], inputs.size(0))

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 100 == 0 and idx != 0:
            logging(epoch, idx, n_batches, loss.item(), acc1, acc5, losses, top1, top5, pretrain=pretrain)
|
||||||
|
|
||||||
|
|
||||||
|
def val(epoch, test_loader, model, optimizer, classifier=None, criterion=None, task=None):
    """Evaluate on ``test_loader`` and return ``(top1_avg, top5_avg, loss_avg)``.

    When ``classifier`` is ``None`` (pretraining) only the model loss is
    computed per batch and ``(0, 0, 0)`` is returned. Otherwise features from
    ``model.tactile_base_q`` are classified and accuracy/loss meters are
    accumulated over the loader.

    Relies on module-level ``device``, ``data_transforms`` and (in classifier
    mode) ``args``.

    Args:
        epoch: Current epoch number (forwarded to the model and to logging).
        test_loader: Iterable of batches, unpacked as ``(inputs, _, index)``
            in pretrain mode and ``(inputs, target)`` in classifier mode.
        model: Network under evaluation; always put in eval mode.
        optimizer: Unused here.
        classifier: Optional classification head; selects classifier mode.
        criterion: Loss for classifier mode.
        task: Unused here.

    Returns:
        ``(0, 0, 0)`` in pretrain mode, else the running averages of top-1
        accuracy, "top-5" accuracy and loss.
    """
    pretrain = classifier is None

    # Logging setup
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    if not pretrain:
        classifier = classifier.to(device)
        criterion = criterion.to(device)
        classifier.eval()

    n_batches = len(test_loader)
    large_loader = n_batches > 100
    # Long loaders report every 100 batches, short ones every 30.
    report_every = 100 if large_loader else 30

    # No backward pass happens during evaluation, so skip building autograd
    # graphs.
    with torch.no_grad():
        for idx, values in enumerate(test_loader):
            if pretrain:
                inputs, _, _ = values
            else:
                inputs, target = values

            x_vision, x_tactile = inputs[:, :3, :, :], inputs[:, 3:, :, :]

            # Augment images
            # NOTE(review): the same random augmentations as training are
            # applied at evaluation time - confirm this is intended.
            x_vision_q = data_transforms(x_vision).to(device)
            x_vision_k = data_transforms(x_vision).to(device)

            x_tactile_q = data_transforms(x_tactile).to(device)
            x_tactile_k = data_transforms(x_tactile).to(device)

            # Forward pass to get the loss
            if pretrain:
                loss = model(x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, idx, n_batches)
                # Accuracy is undefined here; the meters are left intact so
                # later attribute access cannot fail (they used to be
                # replaced by 0, which broke `top1.avg` below).
                acc1, acc5 = 0, 0
            else:
                feat = model.tactile_base_q(x_tactile_q)
                if args.num_layers == 3:
                    # NOTE(review): tactile features go through the *vision*
                    # projection head here - confirm this is intended.
                    feat = model.phi_vision_q(feat)

                output = classifier(feat)
                # Follow the configured device instead of hard-coding CUDA.
                target = target.to(device)
                loss = criterion(output, target)

                # NOTE(review): topk=(1, 1) makes "top5" identical to top-1.
                acc1, acc5 = accuracy(output, target, topk=(1, 1))
                losses.update(loss.item(), inputs.size(0))
                top1.update(acc1[0], inputs.size(0))
                top5.update(acc5[0], inputs.size(0))

            if idx % report_every == 0 and idx != 0:
                logging(epoch, idx, n_batches, loss.item(), acc1, acc5, losses, top1, top5, pretrain=pretrain, train=False)
                if not pretrain:
                    # Accuracy reporting only makes sense with a classifier.
                    if large_loader:
                        # float() accepts both a plain number and a
                        # one-element tensor.
                        wandb.log({"val accuracy": float(top1.avg)}, step=epoch)
                    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))

    if pretrain:
        return 0, 0, 0

    return top1.avg, top5.avg, losses.avg
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Entry point: parse options, build datasets/loaders and the model, then
    # either pretrain the model or train a linear classifier on top of a
    # pretrained checkpoint, saving checkpoints under `models/`.

    # Best top-1 validation accuracy seen so far (stays 0 during pretraining).
    best_acc1 = 0

    # Parse arguments
    args = parse_option()

    # Initialize wandb
    wandb.login()

    # Per-sample preprocessing applied by the dataset.
    simple_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Random augmentation; read as a module-level name by train() and val().
    data_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
        transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.50),
        transforms.RandomGrayscale(p=0.2),
    ])

    # Initialize dataset and dataloader
    data_folder = args.data_folder
    train_sampler = None
    n_labels = None  # only defined for the labelled (non-pretrain) datasets
    if args.dataset == 'touch_hard':
        print('hard')
        train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='train', label='hard')
        val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='test', label='hard')
        n_labels = 2
    elif args.dataset == 'touch_rough':
        print('rough')
        train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='train', label='rough')
        val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='test', label='rough')
        n_labels = 2
    elif args.dataset == 'touch_and_go':
        train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='train')
        val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='test')
        n_labels = 20
    elif args.dataset == 'pretrain':
        train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='pretrain')
        val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='pretrain')
    else:
        raise NotImplementedError('dataset not supported {}'.format(args.dataset))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

    # Evaluation data is iterated in a fixed order; it was previously
    # shuffled via the training sampler setting.
    test_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)

    # num of samples
    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    # Initialize model
    nn_model = 'resnet18'
    model = MultiModalMoCo(m=0.99, T=0.07, nn_model=nn_model).to(device)

    # Initialize task
    task = {
        'pretrain': "pretrain",
        'touch_and_go': "material",
        'touch_rough': "rough",
        'touch_hard': "hard",
    }[args.dataset]

    if task == "pretrain":
        # Initialize optimizer over the query encoders and projection heads.
        modules = (list(model.vision_base_q.parameters())
                   + list(model.tactile_base_q.parameters())
                   + list(model.phi_vision_q.parameters())
                   + list(model.phi_tactile_q.parameters()))
        # --learning_rate defaults to 0.03, the value previously hard-coded.
        optimizer = optim.Adam(modules, lr=args.learning_rate)
        classifier, criterion = None, None
    else:
        # Load pretrained weights; map_location lets a checkpoint saved on
        # GPU load on a CPU-only machine.
        pretrained_path = '/media/vedant/cpsDataStorageWK/Vedant/Tactile/TAG/models/model_tag_pretrain_220.pth'
        model.load_state_dict(torch.load(pretrained_path, map_location=device)['model'])
        classifier = LinearClassifierResNet(layer=args.num_layers, n_label=n_labels)
        optimizer = optim.Adam(classifier.parameters(), lr=1e-4)
        criterion = torch.nn.CrossEntropyLoss()

    # Initialize wandb
    wandb.init(project=task + "_mm")

    # torch.save does not create missing directories.
    os.makedirs('models', exist_ok=True)

    # Train
    for epoch in range(args.epochs):
        train(epoch, train_loader, model, optimizer, classifier=classifier, criterion=criterion, task=task)

        if task != "pretrain":
            print("==> testing...")
            test_acc, test_acc5, test_loss = val(epoch, test_loader, model, optimizer, classifier=classifier, criterion=criterion, task=task)

            # save the best model
            print('test_acc: {}'.format(test_acc))
            print('best_acc1: {}'.format(best_acc1))
            if test_acc > best_acc1:
                best_acc1 = test_acc
                state = {
                    'opt': args,
                    'epoch': epoch,
                    'classifier': classifier.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }
                save_name = f'models/model_tag_{task}_best.pth'
                print('saving best model!')
                torch.save(state, save_name)

        # --save_freq defaults to 10, the interval previously hard-coded.
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'epoch': epoch,
                'model': model.state_dict(),
                # There is no classifier during pretraining; store None rather
                # than the optimizer's state under the 'classifier' key.
                'classifier': classifier.state_dict() if classifier is not None else None,
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }
            ckpt_path = f'models/model_tag_{task}_{epoch}.pth'
            torch.save(state, ckpt_path)
            # Upload the file that was actually written.
            wandb.save(ckpt_path)
|
Loading…
Reference in New Issue
Block a user