from __future__ import print_function

import os

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as TF


class TouchFolderLabel(Dataset):
    """Folder dataset that also returns the index of the image (in pretrain mode)."""

    def __init__(self, root, transform=None, target_transform=None, two_crop=False,
                 mode='train', label='full', data_amount=100):
        self.two_crop = two_crop
        self.dataroot = '/media/vedant/cpsDataStorageWK/Vedant/tactile/TAG/dataset_copy/'
        self.mode = mode

        # Each split file lists one sample per line as "<relative_path>,<label>".
        if mode == 'train':
            with open(os.path.join(root, 'train.txt'), 'r') as f:
                data = f.read().splitlines()
        elif mode == 'test':
            with open(os.path.join(root, 'test.txt'), 'r') as f:
                data = f.read().splitlines()
        elif mode == 'pretrain':
            with open(os.path.join(root, 'pretrain.txt'), 'r') as f:
                data = f.read().splitlines()
        else:
            raise ValueError("mode must be 'train', 'test' or 'pretrain', got {!r}".format(mode))

        # The 'rough' label uses its own split files.
        if mode == 'train' and label == 'rough':
            with open(os.path.join(root, 'train_rough.txt'), 'r') as f:
                data = f.read().splitlines()

        if mode == 'test' and label == 'rough':
            with open(os.path.join(root, 'test_rough.txt'), 'r') as f:
                data = f.read().splitlines()

        self.length = len(data)
        self.env = data
        self.transform = transform
        self.target_transform = target_transform
        self.label = label

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index) where target is the class index of the
                target class; the sample index is only returned in pretrain mode.
        """
        assert index < self.length, 'index_A range error'

        raw, target = self.env[index].strip().split(',')  # "<relative_path>,<label>"
        target = int(target)
        if self.label == 'hard':
            # Binarize the label: classes 7, 8, 9, 11 and 13 map to 1, all others to 0.
            target = 1 if target in (7, 8, 9, 11, 13) else 0

        idx = os.path.basename(raw)
        seq_dir = self.dataroot + raw[:16]  # the first 16 characters of the path select the recording folder

        # Load the camera frame and the corresponding GelSight frame.
        A_img_path = os.path.join(seq_dir, 'video_frame', idx)
        A_gelsight_path = os.path.join(seq_dir, 'gelsight_frame', idx)

        A_img = Image.open(A_img_path).convert('RGB')
        A_gel = Image.open(A_gelsight_path).convert('RGB')

        if self.transform is not None:
            A_img = self.transform(A_img)
            A_gel = self.transform(A_gel)
        else:
            # Fall back to a plain CHW float-tensor conversion when no transform is given.
            A_img = TF.to_tensor(A_img)
            A_gel = TF.to_tensor(A_gel)

        # Concatenate the two modalities along the channel dimension.
        out = torch.cat((A_img, A_gel), dim=0)

        if self.mode == 'pretrain':
            return out, target, index

        return out, target

    def __len__(self):
        """Return the total number of images."""
        return self.length
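

# Minimal usage sketch (not part of the original module). It assumes that `root`
# points at a directory containing the split files (train.txt / test.txt /
# pretrain.txt) with "<relative_path>,<label>" lines, that the hard-coded
# `dataroot` holds the extracted frames, and that torchvision is installed; the
# paths below are placeholders.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    dataset = TouchFolderLabel(root='./splits', transform=transform, mode='train')
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    out, target = next(iter(loader))
    # `out` stacks the camera and GelSight frames along the channel axis.
    print(out.shape, target.shape)  # e.g. torch.Size([4, 6, 224, 224]) torch.Size([4])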