From b3660fe103256821818a535696412ccdfa464849 Mon Sep 17 00:00:00 2001 From: Vedant Dave Date: Tue, 12 Sep 2023 13:16:03 +0000 Subject: [PATCH] Updating --- tac_ssl.py | 429 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 330 insertions(+), 99 deletions(-) diff --git a/tac_ssl.py b/tac_ssl.py index bf59f87..0cb0fca 100644 --- a/tac_ssl.py +++ b/tac_ssl.py @@ -1,118 +1,131 @@ import os +import wandb import pickle +import argparse from PIL import Image - -from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo import matplotlib.pyplot as plt +from util import * +from generate_dataset import TouchFolderLabel +from linear_classifier import LinearClassifierResNet +from train_mm_moco import evaluate_and_plot, compute_tsne, MultiModalMoCo + import torch import torch.optim as optim from torchvision import transforms from torch.utils.data import random_split +from torch.optim.lr_scheduler import StepLR from torch.utils.data import DataLoader, Dataset -from torch.utils.tensorboard import SummaryWriter device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -writer = SummaryWriter('runs/mmssl') -# Custom dataset -class CustomMultiModalDataset(Dataset): - def __init__(self, vision_folder, tactile_folder, transform=None): - self.vision_folder = vision_folder - self.tactile_folder = tactile_folder - self.transform = transform +def parse_option(): + parser = argparse.ArgumentParser('argument for training') - self.vision_files = sorted(os.listdir(vision_folder)) - self.tactile_files = sorted(os.listdir(tactile_folder)) + parser.add_argument('--print_freq', type=int, default=10, help='print frequency') + parser.add_argument('--save_freq', type=int, default=10, help='save frequency') + parser.add_argument('--batch_size', type=int, default=256, help='batch_size') + parser.add_argument('--num_workers', type=int, default=18, help='num of workers to use') + parser.add_argument('--epochs', type=int, default=61, help='number of training epochs') + parser.add_argument('--num_layers', type=int, default=5, help='number of layers in resnet') - def __len__(self): - return len(self.vision_files) + # optimization + parser.add_argument('--learning_rate', type=float, default=0.03, help='learning rate') + parser.add_argument('--lr_decay_epochs', type=str, default='120,160', help='where to decay lr, can be a list') + parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate') + parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam') + parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam') + parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay') + parser.add_argument('--momentum', type=float, default=0.9, help='momentum') - def __getitem__(self, idx): - vision_path = os.path.join(self.vision_folder, self.vision_files[idx]) - tactile_path = os.path.join(self.tactile_folder, self.tactile_files[idx]) + # resume path + parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') - vision_image = Image.open(vision_path).convert("RGB") - tactile_image = Image.open(tactile_path).convert("RGB") + # model definition + parser.add_argument('--model', type=str, default='alexnet', choices=[ + 'resnet50t1', 'resnet101t1', 'resnet18t1', + 'resnet50t2', 'resnet101t2', 'resnet18t2', + 'resnet50t3', 'resnet101t3', 'resnet18t3']) + parser.add_argument('--softmax', action='store_true', help='using softmax contrastive loss rather than NCE') + parser.add_argument('--feat_dim', type=int, default=128, help='dim of feat for inner product') - if self.transform: - vision_image = self.transform(vision_image) - tactile_image = self.transform(tactile_image) + # dataset + parser.add_argument('--dataset', type=str, default='touch_and_go', choices=['touch_and_go', 'pretrain', 'touch_rough', 'touch_hard']) - return vision_image, tactile_image + # specify folder + parser.add_argument('--data_folder', type=str, default="dataset/", help='path to dataset') + parser.add_argument('--data_loader', type=str, default='touch_and_go', choices=['touch_and_go']) + parser.add_argument('--model_path', type=str, default="ckpt/mmssl", help='path to save model') -# Initialize augmentation -simple_transforms = transforms.Compose([ - transforms.Resize((275, 275)), - #transforms.CenterCrop(500), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -]) + # add new views + parser.add_argument('--view', type=str, default='Touch', choices=['Touch']) -data_transforms = transforms.Compose([ - transforms.RandomApply([transforms.RandomRotation(150)], p=0.50), - transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), - transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.50), - #transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), - transforms.RandomGrayscale(p=0.2), - #transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5), -]) + # mixed precision setting + parser.add_argument('--amp', action='store_true', help='using mixed precision') + parser.add_argument('--opt_level', type=str, default='O2', choices=['O1', 'O2']) -# Initialize dataset and dataloader -vision_folder = "/home/vedant/Downloads/ssvtp_data/images_rgb" -tactile_folder = "/home/vedant/Downloads/ssvtp_data/images_tac" -dataset = CustomMultiModalDataset(vision_folder, tactile_folder, transform=simple_transforms) + # data crop threshold + parser.add_argument('--crop_low', type=float, default=0.2, help='low area in crop') -preload = True -if not preload: - # Split the dataset into 80-20 - train_size = int(0.8 * len(dataset)) - test_size = len(dataset) - train_size - train_dataset, test_dataset = random_split(dataset, [train_size, test_size]) + # data amount + parser.add_argument('--data_amount', type=int, default=100, help='how much data used') + parser.add_argument('--comment', type=str, default='', help='comment') - # Get the indices of the training and test sets - train_indices = train_dataset.indices - test_indices = test_dataset.indices + # wandb + parser.add_argument('--wandb', action='store_true', help='Enable wandb') + parser.add_argument('--wandb_name', type=str, default=None, help='username of wandb') - # Save these indices to disk - with open('indices/train_indices.pkl', 'wb') as f: - pickle.dump(train_indices, f) - - with open('indices/test_indices.pkl', 'wb') as f: - pickle.dump(test_indices, f) + opt = parser.parse_args() - # Initialize dataloaders for train and test - train_dataloader = DataLoader(train_dataset, batch_size=96, shuffle=True) - test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False) -else: - # Load the indices from disk - with open('indices/train_indices.pkl', 'rb') as f: - train_indices = pickle.load(f) - - with open('indices/test_indices.pkl', 'rb') as f: - test_indices = pickle.load(f) + if (opt.data_folder is None) or (opt.model_path is None): + raise ValueError('one or more of the folders is None: data_folder | model_path') + + return opt - # Create subset datasets and DataLoaders - train_subset = torch.utils.data.Subset(dataset, train_indices) - test_subset = torch.utils.data.Subset(dataset, test_indices) - train_dataloader = DataLoader(train_subset, batch_size=96, shuffle=True) - test_dataloader = DataLoader(test_subset, batch_size=32, shuffle=False) +def logging(epoch, idx, train_loader_len, loss, acc1, acc5, losses, top1, top5, pretrain, train=True): + if pretrain: + print('Epoch: [{0}][{1}/{2}]\t' + 'Loss {loss:.4f}'.format( + epoch, idx, train_loader_len, loss=loss)) + wandb.log({"training loss": loss}, step=epoch * train_loader_len + idx) + else: + print('Epoch: [{0}][{1}/{2}]\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( + epoch, idx, train_loader_len, loss=losses, top1=top1)) + if train: + wandb.log({"training loss": loss, + "training accuracy": acc1[0], + "training top5 accuracy": acc5[0]}, + step=epoch) + else: + wandb.log({"validation loss": loss, + "validation accuracy": acc1[0], + "validation top5 accuracy": acc5[0]}, + step=epoch) -# Initialize model -model = MultiModalMoCo(writer, K=4096, m=0.99, T=0.07).to(device) +def train(epoch, train_loader, model, optimizer, classifier=None, criterion=None, task=None, scheduler=None): + # Logging setup + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() -# Initialize optimizer -vision_module = list(model.vision_base_q.parameters()) + list(model.vision_head_intra_q.parameters()) + list(model.vision_head_inter_q.parameters()) -tactile_module = list(model.tactile_base_q.parameters()) + list(model.tactile_head_intra_q.parameters()) + list(model.tactile_head_inter_q.parameters()) -optim_vision = optim.Adam(vision_module, lr=0.1) -optim_tactile = optim.Adam(tactile_module, lr=0.1) + # Training loop + for idx, values in enumerate(train_loader): + if classifier is None: + inputs, _, index = values + model.train() + else: + inputs, target = values + classifier = classifier.to(device) + criterion = criterion.to(device) + model.eval() + classifier.train() -# Training loop -n_epochs = 500 # Number of epochs -for epoch in range(n_epochs): - for i, (x_vision, x_tactile) in enumerate(train_dataloader): + x_vision, x_tactile = inputs[:,:3,:,:], inputs[:,3:,:,:] # Augment images x_vision_q = data_transforms(x_vision).to(device) @@ -122,22 +135,240 @@ for epoch in range(n_epochs): x_tactile_k = data_transforms(x_tactile).to(device) # Forward pass to get the loss - loss = model(x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, i, len(train_dataloader)) + if classifier is None: + loss = model(x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, idx, len(train_loader)) - # Backward pass and optimization - optim_vision.zero_grad() - optim_tactile.zero_grad() - loss.backward() - optim_vision.step() - optim_tactile.step() + # Backward pass and optimization + optimizer.zero_grad() + loss.backward() + optimizer.step() + #scheduler.step() - # Logging - if i % 10 == 0: - print(f"Epoch [{epoch+1}/{n_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}") - writer.add_scalar('training loss', loss.item(), epoch * len(train_dataloader) + i) + # For logging + acc1, acc5, losses, top1, top5 = 0, 0, 0, 0, 0 + else: + feat = model.tactile_base_q(x_tactile_q) + if args.num_layers == 3: + feat = model.phi_vision_q(feat) - # Evaluate and plot - #compute_tsne(model, test_dataloader, writer, epoch) - #evaluate_and_plot(model, test_dataloader, epoch, writer, device) - if epoch % 10 == 0: - torch.save(model.state_dict(), 'models/model.pth') \ No newline at end of file + output = classifier(feat) + target = target.cuda() + loss = criterion(output, target) + + acc1, acc5 = accuracy(output, target, topk=(1, 1)) + losses.update(loss.item(), inputs.size(0)) + top1.update(acc1[0], inputs.size(0)) + top5.update(acc5[0], inputs.size(0)) + + # Backward pass and optimization + optimizer.zero_grad() + loss.backward() + optimizer.step() + #scheduler.step() + + if idx % 100 == 0 and idx!=0: + logging(epoch, idx, len(train_loader), loss.item(), acc1, acc5, losses, top1, top5, pretrain=(classifier is None)) + + +def val(epoch, test_loader, model, optimizer, classifier=None, criterion=None, task=None): + # Logging setup + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # Training loop + for idx, values in enumerate(test_loader): + if classifier is None: + inputs, _, index = values + model.eval() + else: + inputs, target = values + classifier = classifier.to(device) + criterion = criterion.to(device) + model.eval() + classifier.eval() + + x_vision, x_tactile = inputs[:,:3,:,:], inputs[:,3:,:,:] + + # Augment images + x_vision_q = data_transforms(x_vision).to(device) + x_vision_k = data_transforms(x_vision).to(device) + + x_tactile_q = data_transforms(x_tactile).to(device) + x_tactile_k = data_transforms(x_tactile).to(device) + + # Forward pass to get the loss + if classifier is None: + loss = model(x_vision_q, x_vision_k, x_tactile_q, x_tactile_k, epoch, idx, len(test_loader)) + + # For logging + acc1, acc5, losses, top1, top5 = 0, 0, 0, 0, 0 + else: + feat = model.tactile_base_q(x_tactile_q) + if args.num_layers == 3: + feat = model.phi_vision_q(feat) + + output = classifier(feat) + target = target.cuda() + loss = criterion(output, target) + + acc1, acc5 = accuracy(output, target, topk=(1, 1)) + losses.update(loss.item(), inputs.size(0)) + top1.update(acc1[0], inputs.size(0)) + top5.update(acc5[0], inputs.size(0)) + + + if len(test_loader) > 100: + if idx % 100 == 0 and idx!=0: + logging(epoch, idx, len(test_loader), loss.item(), acc1, acc5, losses, top1, top5, pretrain=(classifier is None), train=False) + wandb.log({"val accuracy": top1.avg.item()}, step=epoch) + print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) + else: + if idx % 30 == 0 and idx!=0: + logging(epoch, idx, len(test_loader), loss.item(), acc1, acc5, losses, top1, top5, pretrain=(classifier is None), train=False) + #wandb.log({"val accuracy": top1.avg.item()}, step=epoch) + print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) + + if classifier is None: + return 0, 0, 0 + + return top1.avg, top5.avg, losses.avg + + +if __name__ == "__main__": + # Best Accuracy + global best_acc1 + best_acc1 = 0 + + # Parse arguments + args = parse_option() + + # Initialize wandb + wandb.login() + #wandb.init(project='tac_tag') + + # Initialize augmentation + simple_transforms = transforms.Compose([ + transforms.Resize((256, 256)), + #transforms.CenterCrop(500), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + data_transforms = transforms.Compose([ + #transforms.RandomApply([transforms.RandomRotation(150)], p=0.50), + transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), + transforms.RandomApply([transforms.RandomHorizontalFlip()], p=0.50), + #transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), + transforms.RandomGrayscale(p=0.2), + #transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5), + ]) + + # Initialize dataset and dataloader + data_folder = args.data_folder + train_sampler = None + if args.dataset == 'touch_and_go' or args.dataset == 'touch_rough' or args.dataset == 'touch_hard' or args.dataset == 'pretrain': + if args.dataset == 'touch_hard': + print('hard') + train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='train', label='hard') + val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='test', label='hard') + n_labels = 2 + elif args.dataset == 'touch_rough': + print('rough') + train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='train', label='rough') + val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='test', label='rough') + n_labels = 2 + elif args.dataset == 'touch_and_go': + train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='train') + val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='test') + n_labels = 20 + elif args.dataset == 'pretrain': + train_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='pretrain') + val_dataset = TouchFolderLabel(data_folder, transform=simple_transforms, mode='pretrain') + else: + raise NotImplementedError('dataset not supported {}'.format(args.dataset)) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.num_workers, pin_memory=True, sampler=train_sampler) + + test_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.num_workers, pin_memory=True, sampler=train_sampler) + + # num of samples + n_data = len(train_dataset) + print('number of samples: {}'.format(n_data)) + + # Initialize model + nn_model = 'resnet18' + model = MultiModalMoCo(m=0.99, T=0.07, nn_model=nn_model).to(device) + + # Initialize task + if args.dataset == 'pretrain': + task = "pretrain" + elif args.dataset == 'touch_and_go': + task = "material" + elif args.dataset == 'touch_rough': + task = "rough" + elif args.dataset == 'touch_hard': + task = "hard" + + if task == "pretrain": + # Initialize optimizer + modules = list(model.vision_base_q.parameters()) + list(model.tactile_base_q.parameters()) + list(model.phi_vision_q.parameters()) + list(model.phi_tactile_q.parameters()) + optimizer = optim.Adam(modules, lr=0.03) + classifier, criterion = None, None + else: + # Initialize training + #model.load_state_dict(torch.load('/home/cpsadmin/TAG/models/model_tag_238.pth')) + model.load_state_dict(torch.load('/media/vedant/cpsDataStorageWK/Vedant/Tactile/TAG/models/model_tag_pretrain_220.pth')['model']) + classifier = LinearClassifierResNet(layer=args.num_layers, n_label=n_labels) + optimizer = optim.Adam(classifier.parameters(), lr=1e-4) + criterion = torch.nn.CrossEntropyLoss() + + + # Initialize wandb + wandb.init(project=task+"_mm") + + # Train + for epoch in range(args.epochs): + train(epoch, train_loader, model, optimizer, classifier=classifier, criterion=criterion, task=task) + + if task == "pretrain": + pass + else: + print("==> testing...") + test_acc, test_acc5, test_loss = val(epoch, test_loader, model, optimizer, classifier=classifier, criterion=criterion, task=task) + + # save the best model + print('test_acc: {}'.format(test_acc)) + print('best_acc1: {}'.format(best_acc1)) + if test_acc > best_acc1: + best_acc1 = test_acc + state = { + 'opt': args, + 'epoch': epoch, + 'classifier': classifier.state_dict(), + 'best_acc1': best_acc1, + 'optimizer': optimizer.state_dict(), + } + save_name = f'models/model_tag_{task}_best.pth' + print('saving best model!') + torch.save(state, save_name) + + if epoch % 10 == 0: + print('==> Saving...') + classifier_ = classifier if task != "pretrain" else optimizer + optimizer_ = optimizer + best_acc1 = best_acc1 if task != "pretrain" else 0 + state = { + 'opt': args, + 'epoch': epoch, + 'model': model.state_dict(), + 'classifier': classifier_.state_dict(), + 'best_acc1': best_acc1, + 'optimizer': optimizer_.state_dict(), + } + torch.save(state, f'models/model_tag_{task}_{epoch}.pth') + wandb.save('models/model_{}.pth'.format(epoch)) \ No newline at end of file