TVSSL/generate_dataset.py

94 lines
2.9 KiB
Python
Raw Permalink Normal View History

2023-09-12 13:32:15 +00:00
from __future__ import print_function
from PIL import Image
import torch
import os
from torch.utils.data import Dataset
class TouchFolderLabel(Dataset):
    """Dataset of paired video frames and GelSight frames listed in split files.

    Each non-blank line of the split file has the form ``<raw_path>,<target>``.
    For a given line the sample is built from two images that share the same
    file name (the basename of ``raw_path``):

    * ``<dataroot><raw_path[:16]>/video_frame/<basename>``
    * ``<dataroot><raw_path[:16]>/gelsight_frame/<basename>``

    Both images are converted to RGB, passed through ``transform`` and
    concatenated along dimension 0.

    Args:
        root: Directory containing the split list files (``train.txt``,
            ``test.txt``, ``pretrain.txt``, ``train_rough.txt``,
            ``test_rough.txt``).
        transform: Optional callable applied separately to the video frame
            and to the GelSight frame.
        target_transform: Stored on the instance; not applied by this class.
        two_crop: Stored on the instance; not used by this class.
        mode: One of ``'train'``, ``'test'`` or ``'pretrain'``.
        label: ``'rough'`` selects the ``*_rough.txt`` list for the train and
            test modes; ``'hard'`` collapses the target to a binary label
            (1 for ids 7, 8, 9, 11, 13, otherwise 0); any other value keeps
            the integer target from the list file.
        data_amount: Accepted for interface compatibility; unused here.
        dataroot: Optional directory holding the recordings. Defaults to
            ``DEFAULT_DATAROOT``. A trailing path separator is added when
            missing because the recording prefix is appended by plain string
            concatenation.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    # Original hard-coded location of the recordings, kept as the default.
    DEFAULT_DATAROOT = '/media/vedant/cpsDataStorageWK/Vedant/tactile/TAG/dataset_copy/'

    _VALID_MODES = ('train', 'test', 'pretrain')
    # Modes for which a separate '<mode>_rough.txt' list exists.
    _ROUGH_MODES = ('train', 'test')
    # Target ids mapped to 1 when label == 'hard'; every other id maps to 0.
    _HARD_TARGETS = frozenset((7, 8, 9, 11, 13))
    # Number of leading characters of the raw path used as the recording
    # directory. NOTE(review): presumably the fixed-width recording folder
    # name in the list files -- confirm against the dataset layout.
    _RECORDING_PREFIX_LEN = 16

    def __init__(self, root, transform=None, target_transform=None, two_crop=False,
                 mode='train', label='full', data_amount=100, dataroot=None):
        if mode not in self._VALID_MODES:
            # Raise instead of print + exit(): library code should not
            # terminate the interpreter on a bad argument.
            raise ValueError(
                "mode must be one of {}, got {!r}".format(self._VALID_MODES, mode))

        self.two_crop = two_crop
        if dataroot is None:
            self.dataroot = self.DEFAULT_DATAROOT
        else:
            # Joining with '' guarantees exactly one trailing separator.
            self.dataroot = os.path.join(dataroot, '')
        self.mode = mode

        # Pick the single list file that is actually needed.
        list_name = mode
        if label == 'rough' and mode in self._ROUGH_MODES:
            list_name = mode + '_rough'
        with open(os.path.join(root, list_name + '.txt'), 'r') as f:
            # Drop blank lines (e.g. produced by a trailing newline); they
            # would otherwise count as samples and fail to parse later.
            entries = [line for line in f.read().split('\n') if line.strip()]

        self.length = len(entries)
        self.env = entries
        self.transform = transform
        self.target_transform = target_transform
        self.label = label

    def __getitem__(self, index):
        """Load the sample described by list entry ``index``.

        Args:
            index (int): Position in the split list.

        Returns:
            tuple: ``(out, target, index)`` when ``mode == 'pretrain'``,
            otherwise ``(out, target)``. ``out`` is the video frame and the
            GelSight frame concatenated along dimension 0; ``target`` is the
            integer label (binary when ``label == 'hard'``).

        Raises:
            IndexError: If ``index`` is not smaller than the dataset length.
        """
        if index >= self.length:
            # IndexError (not assert) so the check survives ``python -O`` and
            # plain iteration over the dataset terminates cleanly.
            raise IndexError('index_A range error')

        raw, target = self.env[index].strip().split(',')
        target = int(target)
        if self.label == 'hard':
            target = 1 if target in self._HARD_TARGETS else 0

        frame_name = os.path.basename(raw)
        recording_dir = self.dataroot + raw[:self._RECORDING_PREFIX_LEN]

        # The camera frame and the GelSight frame share the same file name.
        A_img_path = os.path.join(recording_dir, 'video_frame', frame_name)
        A_gelsight_path = os.path.join(recording_dir, 'gelsight_frame', frame_name)
        A_img = Image.open(A_img_path).convert('RGB')
        A_gel = Image.open(A_gelsight_path).convert('RGB')

        if self.transform is not None:
            # The transform is invoked once per image, so a randomised
            # transform draws its parameters independently for each.
            A_img = self.transform(A_img)
            A_gel = self.transform(A_gel)
        else:
            # NOTE(review): kept from the original -- confirm that building a
            # tensor directly from a PIL image gives the layout expected by
            # the dim-0 concatenation below.
            A_img = torch.Tensor(A_img)
            A_gel = torch.Tensor(A_gel)

        out = torch.cat((A_img, A_gel), dim=0)
        if self.mode == 'pretrain':
            return out, target, index
        return out, target

    def __len__(self):
        """Return the number of (non-blank) entries in the split list."""
        return self.length