from __future__ import print_function from PIL import Image import torch import os from torch.utils.data import Dataset class TouchFolderLabel(Dataset): """Folder datasets which returns the index of the image as well """ def __init__(self, root, transform=None, target_transform=None, two_crop=False, mode='train', label='full', data_amount=100): self.two_crop = two_crop self.dataroot = '/media/vedant/cpsDataStorageWK/Vedant/tactile/TAG/dataset_copy/' self.mode = mode if mode == 'train': with open(os.path.join(root, 'train.txt'),'r') as f: data = f.read().split('\n') elif mode == 'test': with open(os.path.join(root, 'test.txt'),'r') as f: data = f.read().split('\n') elif mode == 'pretrain': with open(os.path.join(root, 'pretrain.txt'),'r') as f: data = f.read().split('\n') else: print('Mode other than train and test') exit() if mode == 'train' and label == 'rough': with open(os.path.join(root, 'train_rough.txt'),'r') as f: data = f.read().split('\n') if mode == 'test' and label == 'rough': with open(os.path.join(root, 'test_rough.txt'),'r') as f: data = f.read().split('\n') self.length = len(data) self.env = data self.transform = transform self.target_transform = target_transform self.label = label def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target, index) where target is class_index of the target class. """ assert index < self.length,'index_A range error' raw, target = self.env[index].strip().split(',') # mother path for A target = int(target) if self.label == 'hard': if target == 7 or target == 8 or target == 9 or target == 11 or target == 13: target = 1 else: target = 0 idx = os.path.basename(raw) dir = self.dataroot + raw[:16] # load image and gelsight A_img_path = os.path.join(dir, 'video_frame', idx) A_gelsight_path = os.path.join(dir, 'gelsight_frame', idx) A_img = Image.open(A_img_path).convert('RGB') A_gel = Image.open(A_gelsight_path).convert('RGB') if self.transform is not None: A_img = self.transform(A_img) A_gel = self.transform(A_gel) else: A_img = torch.Tensor(A_img) A_gel = torch.Tensor(A_gel) out = torch.cat((A_img, A_gel), dim=0) if self.mode == 'pretrain': return out, target, index return out, target def __len__(self): """Return the total number of images.""" return self.length