TVSSL/read_data.py

20 lines
487 B
Python
Raw Normal View History

2023-09-12 13:32:15 +00:00
import torch
from train_mm_moco import MultiModalMoCo
# Build the multimodal MoCo model and print the epoch and best top-1 accuracy
# recorded in a saved training checkpoint.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoint to inspect. This is a machine-specific absolute path; change this
# constant when running on another machine.
CHECKPOINT_PATH = '/media/vedant/cpsDataStorageWK/Vedant/Tactile/TAG/models/model_tag_material_best.pth'

# Initialize model.
# NOTE(review): m and T are presumably the MoCo key-encoder momentum and the
# contrastive temperature; confirm against train_mm_moco.MultiModalMoCo.
nn_model = 'resnet18'
model = MultiModalMoCo(m=0.99, T=0.07, nn_model=nn_model).to(device)

# Load the checkpoint. map_location remaps tensors that were saved on a CUDA
# device onto `device`, so this also works on a CPU-only machine. Without it,
# torch.load fails there when the checkpoint was saved from GPU tensors.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# trusted checkpoints. On PyTorch versions whose default is weights_only=True,
# the 'opt' entry may be rejected; if so, pass weights_only=False for trusted
# files.
saved_state = torch.load(CHECKPOINT_PATH, map_location=device)
epoch = saved_state['epoch']
best_acc1 = saved_state['best_acc1']
args = saved_state['opt']  # presumably the training options object; confirm
# NOTE(review): no weights are copied into `model` here. That would need
# model.load_state_dict(saved_state[<key>]) using the key chosen at save time,
# and that key is not visible in this file.
print(epoch, best_acc1)