Update 'tac_ssl_tag.py'

Vedant Dave 2023-09-12 14:06:33 +00:00
parent 29c80c6d92
commit 913e000336


@@ -25,7 +25,7 @@ def parse_option():
     parser.add_argument('--print_freq', type=int, default=10, help='print frequency')
     parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
-    parser.add_argument('--batch_size', type=int, default=256, help='batch_size')
+    parser.add_argument('--batch_size', type=int, default=128, help='batch_size')
     parser.add_argument('--num_workers', type=int, default=18, help='num of workers to use')
     parser.add_argument('--epochs', type=int, default=61, help='number of training epochs')
     parser.add_argument('--num_layers', type=int, default=5, help='number of layers in resnet')
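
For context only (not part of the diff): batch_size and num_workers from parse_option() would typically be passed to a torch.utils.data.DataLoader. The sketch below assumes a generic placeholder dataset, not the repository's actual data pipeline.

import argparse
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical illustration of how the parsed options are usually consumed.
# The dummy tensors stand in for the real visuo-tactile dataset.
args = argparse.Namespace(batch_size=128, num_workers=18)
dataset = TensorDataset(torch.randn(1024, 3, 64, 64))  # placeholder samples
loader = DataLoader(dataset,
                    batch_size=args.batch_size,   # 256 -> 128 in this commit
                    num_workers=args.num_workers,
                    shuffle=True,
                    drop_last=True)               # common for contrastive batches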
@@ -302,7 +302,7 @@ if __name__ == "__main__":
     # Initialize model
     nn_model = 'resnet18'
-    model = MultiModalMoCo(m=0.99, T=0.07, nn_model=nn_model).to(device)
+    model = MultiModalMoCo(m=0.99, T=0.2, nn_model=nn_model).to(device)

     # Initialize task
     if args.dataset == 'pretrain':
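
The T argument is the softmax temperature of the contrastive objective; raising it from 0.07 to 0.2 softens the logits. The MultiModalMoCo internals are not part of this diff, so the sketch below follows the standard MoCo-style InfoNCE formulation rather than the repository's exact code.

import torch
import torch.nn.functional as F

def info_nce_loss(q, k, queue, T=0.2):
    """MoCo-style InfoNCE (sketch, not the repo's implementation).

    q:     (N, C) query embeddings, L2-normalized
    k:     (N, C) positive key embeddings, L2-normalized
    queue: (C, K) negative keys from the memory queue
    """
    l_pos = torch.einsum('nc,nc->n', q, k).unsqueeze(-1)  # (N, 1) positive logits
    l_neg = torch.einsum('nc,ck->nk', q, queue)           # (N, K) negative logits
    logits = torch.cat([l_pos, l_neg], dim=1) / T         # temperature scaling
    labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, labels)                # positives are class 0

A larger T spreads probability mass over the negatives, making the loss less dominated by the hardest negatives.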
@@ -316,12 +316,12 @@ if __name__ == "__main__":
     if task == "pretrain":
         # Initialize optimizer
-        modules = list(model.vision_base_q.parameters()) + list(model.tactile_base_q.parameters()) + list(model.phi_vision_q.parameters()) + list(model.phi_tactile_q.parameters())
+        modules = list(model.vision_base_q.parameters()) + list(model.tactile_base_q.parameters()) + list(model.phi_vision_q.parameters()) + list(model.phi_tactile_q.parameters()) + \
+                  list(model.Phi_vision_q.parameters()) + list(model.Phi_tactile_q.parameters())
         optimizer = optim.Adam(modules, lr=0.03)
         classifier, criterion = None, None
     else:
         # Initialize training
         #model.load_state_dict(torch.load('/home/cpsadmin/TAG/models/model_tag_238.pth'))
         #model.load_state_dict(torch.load('/media/vedant/cpsDataStorageWK/Vedant/Tactile/TAG/models/model_tag_238.pth'))
         model.load_state_dict(torch.load('/media/vedant/cpsDataStorageWK/Vedant/Tactile/TAG/models/model_tag_pretrain_220.pth')['model'])
         classifier = LinearClassifierResNet(layer=args.num_layers, n_label=n_labels)
         optimizer = optim.Adam(classifier.parameters(), lr=1e-4)
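
The pretrain branch now also optimizes the query-side Phi_vision_q and Phi_tactile_q heads. Below is a sketch of the same pattern as a helper function; the attribute names are taken from the diff, but the roles of the phi_* vs. Phi_* heads and the class definition itself are assumptions, since they are not shown here.

import itertools
import torch.optim as optim

def build_pretrain_optimizer(model, lr=0.03):
    """One Adam optimizer over all query-side submodules of MultiModalMoCo
    (sketch mirroring the updated `modules = ...` line above)."""
    query_modules = [
        model.vision_base_q, model.tactile_base_q,  # query backbones
        model.phi_vision_q, model.phi_tactile_q,    # projection heads
        model.Phi_vision_q, model.Phi_tactile_q,    # heads newly added to the optimizer in this commit
    ]
    params = itertools.chain.from_iterable(m.parameters() for m in query_modules)
    return optim.Adam(params, lr=lr)

Chaining the parameter iterators is equivalent to concatenating list(...) calls with +, as the script does; the key change is simply that the two extra heads now receive gradient updates during pretraining.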