import os
import wandb
import pickle
import argparse
from PIL import Image
import matplotlib.pyplot as plt

from util import *
from generate_dataset import TouchFolderLabel
from linear_classifier import LinearClassifierResNet
from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo

import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import random_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def parse_option():
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10, help='print frequency')
    parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--num_workers', type=int, default=18, help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=61, help='number of training epochs')
    parser.add_argument('--num_layers', type=int, default=5, help='number of layers in resnet')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.03, help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='120,160', help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    # resume path
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')

    # model definition (note: currently unused -- the encoder is hardcoded to
    # resnet18 in __main__, and the default 'alexnet' is not among the choices)
    parser.add_argument('--model', type=str, default='alexnet', choices=[
        'resnet50t1', 'resnet101t1', 'resnet18t1',
        'resnet50t2', 'resnet101t2', 'resnet18t2',
        'resnet50t3', 'resnet101t3', 'resnet18t3'])
    parser.add_argument('--softmax', action='store_true', help='using softmax contrastive loss rather than NCE')
    parser.add_argument('--feat_dim', type=int, default=128, help='dim of feat for inner product')

    # dataset
    parser.add_argument('--dataset', type=str, default='touch_and_go',
                        choices=['touch_and_go', 'pretrain', 'touch_rough', 'touch_hard'])

    # specify folder
    parser.add_argument('--data_folder', type=str, default="dataset/", help='path to dataset')
    parser.add_argument('--data_loader', type=str, default='touch_and_go', choices=['touch_and_go'])
    parser.add_argument('--model_path', type=str, default="ckpt/mmssl", help='path to save model')

    # add new views
    parser.add_argument('--view', type=str, default='Touch', choices=['Touch'])

    # mixed precision setting
    parser.add_argument('--amp', action='store_true', help='using mixed precision')
    parser.add_argument('--opt_level', type=str, default='O2', choices=['O1', 'O2'])

    # data crop threshold
    parser.add_argument('--crop_low', type=float, default=0.2, help='low area in crop')

    # data amount
    parser.add_argument('--data_amount', type=int, default=100, help='how much data used')
    parser.add_argument('--comment', type=str, default='', help='comment')

    # wandb
    parser.add_argument('--wandb', action='store_true', help='enable wandb')
    parser.add_argument('--wandb_name', type=str, default=None, help='username of wandb')

    opt = parser.parse_args()
    if (opt.data_folder is None) or (opt.model_path is None):
        raise ValueError('one or more of the folders is None: data_folder | model_path')

    return opt


def logging(epoch, idx, train_loader_len, loss, acc1, acc5, losses, top1, top5, pretrain, train=True):
    if pretrain:
        print('Epoch: [{0}][{1}/{2}]\t'
              'Loss {loss:.4f}'.format(epoch, idx, train_loader_len, loss=loss))
        wandb.log({"training loss": loss}, step=epoch * train_loader_len + idx)
    else:
        print('Epoch: [{0}][{1}/{2}]\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                  epoch, idx, train_loader_len, loss=losses, top1=top1))
        if train:
            wandb.log({"training loss": loss,
                       "training accuracy": acc1[0],
                       "training top5 accuracy": acc5[0]}, step=epoch)
        else:
            wandb.log({"validation loss": loss,
                       "validation accuracy": acc1[0],
                       "validation top5 accuracy": acc5[0]}, step=epoch)


def train(epoch, train_loader, model, optimizer, classifier=None, criterion=None, task=None, scheduler=None):
    # Logging setup
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Training loop
    for idx, values in enumerate(train_loader):
        if classifier is None:
            # Self-supervised pretraining: the loader yields (inputs, label, index)
            inputs, _, index = values
            model.train()
        else:
            # Linear probing: freeze the encoder, train only the classifier
            inputs, target = values
            classifier = classifier.to(device)
            criterion = criterion.to(device)
            model.eval()
            classifier.train()

        # The first 3 channels are the vision image, the last 3 the tactile image
        x_vision, x_tactile = inputs[:, :3, :, :], inputs[:, 3:, :, :]

        # Augment images: two independent views per modality (query and key)
        x_vision_q = data_transforms(x_vision).to(device)
        x_vision_k = data_transforms(x_vision).to(device)
        x_tactile_q = data_transforms(x_tactile).to(device)
        x_tactile_k = data_transforms(x_tactile).to(device)

        if classifier is None:
            # Forward pass to get the contrastive loss
            loss = model(x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, idx, len(train_loader))

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # scheduler.step()

            # For logging
            acc1, acc5, losses, top1, top5 = 0, 0, 0, 0, 0
        else:
            # Extract features from the frozen tactile encoder; no gradients needed here
            with torch.no_grad():
                feat = model.tactile_base_q(x_tactile_q)
                if args.num_layers == 3:
                    feat = model.phi_vision_q(feat)
            output = classifier(feat)
            target = target.to(device)
            loss = criterion(output, target)

            # Note: top-5 is reported as top-1 (topk=(1, 1)) because the
            # rough/hard tasks only have 2 classes
            acc1, acc5 = accuracy(output, target, topk=(1, 1))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0], inputs.size(0))
            top5.update(acc5[0], inputs.size(0))

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # scheduler.step()

        if idx % 100 == 0 and idx != 0:
            logging(epoch, idx, len(train_loader), loss.item(), acc1, acc5, losses, top1, top5,
                    pretrain=(classifier is None))


def val(epoch, test_loader, model, optimizer, classifier=None, criterion=None, task=None):
    # Logging setup
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Validation loop
    for idx, values in enumerate(test_loader):
        if classifier is None:
            inputs, _, index = values
            model.eval()
        else:
            inputs, target = values
            classifier = classifier.to(device)
            criterion = criterion.to(device)
            model.eval()
            classifier.eval()

        x_vision, x_tactile = inputs[:, :3, :, :], inputs[:, 3:, :, :]

        # Augment images
        x_vision_q = data_transforms(x_vision).to(device)
        x_vision_k = data_transforms(x_vision).to(device)
        x_tactile_q = data_transforms(x_tactile).to(device)
        x_tactile_k = data_transforms(x_tactile).to(device)

        if classifier is None:
            # Forward pass to get the loss; no backward pass during validation
            with torch.no_grad():
                loss = model(x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, idx, len(test_loader))

            # For logging
            acc1, acc5, losses, top1, top5 = 0, 0, 0, 0, 0
        else:
            with torch.no_grad():
                feat = model.tactile_base_q(x_tactile_q)
                if args.num_layers == 3:
                    feat = model.phi_vision_q(feat)
                output = classifier(feat)
            target = target.to(device)
            loss = criterion(output, target)

            # As in train(), top-5 falls back to top-1 for the 2-class tasks
            acc1, acc5 = accuracy(output, target, topk=(1, 1))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0], inputs.size(0))
            top5.update(acc5[0], inputs.size(0))

        if len(test_loader) > 100:
            if idx % 100 == 0 and idx != 0:
                logging(epoch, idx, len(test_loader), loss.item(), acc1, acc5, losses, top1, top5,
                        pretrain=(classifier is None), train=False)
                wandb.log({"val accuracy": top1.avg.item()}, step=epoch)
                print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
        else:
            if idx % 30 == 0 and idx != 0:
                logging(epoch, idx, len(test_loader), loss.item(), acc1, acc5, losses, top1, top5,
                        pretrain=(classifier is None), train=False)
                # wandb.log({"val accuracy": top1.avg.item()}, step=epoch)
                print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))

    if classifier is None:
        return 0, 0, 0
    return top1.avg, top5.avg, losses.avg


if __name__ == "__main__":
    # Best accuracy so far (used to checkpoint the best linear classifier)
    best_acc1 = 0

    # Parse arguments
    args = parse_option()

    # Initialize wandb
    wandb.login()
    # wandb.init(project='tac_tag')

    # Initialize augmentations: simple_transforms is applied once by the dataset;
    # data_transforms produces the random views for contrastive training
    simple_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        # transforms.CenterCrop(500),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    data_transforms = transforms.Compose([
        # transforms.RandomApply([transforms.RandomRotation(150)], p=0.50),
        transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
        transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.50),
        # transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        # transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5),
    ])

    # Initialize dataset and dataloaders
    data_folder = args.data_folder
    train_sampler = None

    if args.dataset in ('touch_and_go', 'touch_rough', 'touch_hard', 'pretrain'):
        if args.dataset == 'touch_hard':
            print('hard')
            train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='train', label='hard')
            val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='test', label='hard')
            n_labels = 2
        elif args.dataset == 'touch_rough':
            print('rough')
            train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='train', label='rough')
            val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='test', label='rough')
            n_labels = 2
        elif args.dataset == 'touch_and_go':
            train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='train')
            val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='test')
            n_labels = 20
        elif args.dataset == 'pretrain':
            train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='pretrain')
            val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='pretrain')
    else:
        raise NotImplementedError('dataset not supported: {}'.format(args.dataset))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)
    # The test set does not need to be shuffled
    test_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

    # Number of samples
    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    # Initialize model
    nn_model = 'resnet18'
    model = MultiModalMoCo(m=0.99, T=0.2, nn_model=nn_model).to(device)
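    # MoCo hyperparameters, in the standard MoCo formulation (the actual update
    # lives in train_mm_moco.MultiModalMoCo, so this is a reading aid, not a spec):
    #   m=0.99: momentum for the key-encoder update, theta_k <- m*theta_k + (1-m)*theta_q
    #   T=0.2:  softmax temperature of the InfoNCE contrastive loss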
    # Initialize task
    if args.dataset == 'pretrain':
        task = "pretrain"
    elif args.dataset == 'touch_and_go':
        task = "material"
    elif args.dataset == 'touch_rough':
        task = "rough"
    elif args.dataset == 'touch_hard':
        task = "hard"

    if task == "pretrain":
        # Initialize optimizer over all query-side parameters (the key-side
        # encoders are updated by momentum, not by gradients)
        modules = (list(model.vision_base_q.parameters()) + list(model.tactile_base_q.parameters())
                   + list(model.phi_vision_q.parameters()) + list(model.phi_tactile_q.parameters())
                   + list(model.Phi_vision_q.parameters()) + list(model.Phi_tactile_q.parameters()))
        optimizer = optim.Adam(modules, lr=0.03)
        classifier, criterion = None, None
    else:
        # Initialize linear probing: load a pretrained checkpoint, then train
        # only the linear classifier on top of the frozen encoder
        # model.load_state_dict(torch.load('/media/vedant/cpsDataStorageWK/Vedant/Tactile/TAG/models/model_tag_238.pth'))
        model.load_state_dict(torch.load('/media/vedant/cpsDataStorageWK/Vedant/Tactile/TAG/models/model_tag_pretrain_220.pth')['model'])
        classifier = LinearClassifierResNet(layer=args.num_layers, n_label=n_labels)
        optimizer = optim.Adam(classifier.parameters(), lr=1e-4)
        criterion = torch.nn.CrossEntropyLoss()

    # Initialize wandb
    wandb.init(project=task + "_mm")

    # Train
    for epoch in range(args.epochs):
        train(epoch, train_loader, model, optimizer, classifier=classifier, criterion=criterion, task=task)

        if task != "pretrain":
            print("==> testing...")
            test_acc, test_acc5, test_loss = val(epoch, test_loader, model, optimizer,
                                                 classifier=classifier, criterion=criterion, task=task)

            # Save the best model
            print('test_acc: {}'.format(test_acc))
            print('best_acc1: {}'.format(best_acc1))
            if test_acc > best_acc1:
                best_acc1 = test_acc
                state = {
                    'opt': args,
                    'epoch': epoch,
                    'classifier': classifier.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }
                save_name = f'models/model_tag_{task}_best.pth'
                print('saving best model!')
                torch.save(state, save_name)

        # Periodic checkpoint (there is no classifier or best accuracy during pretraining)
        if epoch % 10 == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'epoch': epoch,
                'model': model.state_dict(),
                'classifier': classifier.state_dict() if classifier is not None else None,
                'best_acc1': best_acc1 if task != "pretrain" else 0,
                'optimizer': optimizer.state_dict(),
            }
            save_name = f'models/model_tag_{task}_{epoch}.pth'
            torch.save(state, save_name)
            wandb.save(save_name)
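# Example invocations (a sketch: the script filename "train_downstream.py" and the
# dataset layout under dataset/ are assumptions, not taken from this file):
#
#   # Stage 1: self-supervised vision-touch pretraining
#   python train_downstream.py --dataset pretrain --data_folder dataset/ --batch_size 128
#
#   # Stage 2: linear probing of the frozen tactile encoder, e.g. on material labels
#   python train_downstream.py --dataset touch_and_go --data_folder dataset/ --num_layers 5
#
# Note that stage 2 loads a hardcoded pretraining checkpoint path above; point it
# at your own checkpoint before running.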