From 913e00033656729bc5719a16891a0e1ffd28e1a3 Mon Sep 17 00:00:00 2001 From: Vedant Dave Date: Tue, 12 Sep 2023 14:06:33 +0000 Subject: [PATCH] Update 'tac_ssl_tag.py' --- tac_ssl_tag.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tac_ssl_tag.py b/tac_ssl_tag.py index 0cb0fca..69159a3 100755 --- a/tac_ssl_tag.py +++ b/tac_ssl_tag.py @@ -25,7 +25,7 @@ def parse_option(): parser.add_argument('--print_freq', type=int, default=10, help='print frequency') parser.add_argument('--save_freq', type=int, default=10, help='save frequency') - parser.add_argument('--batch_size', type=int, default=256, help='batch_size') + parser.add_argument('--batch_size', type=int, default=128, help='batch_size') parser.add_argument('--num_workers', type=int, default=18, help='num of workers to use') parser.add_argument('--epochs', type=int, default=61, help='number of training epochs') parser.add_argument('--num_layers', type=int, default=5, help='number of layers in resnet') @@ -302,7 +302,7 @@ if __name__ == "__main__": # Initialize model nn_model = 'resnet18' - model = MultiModalMoCo(m=0.99, T=0.07, nn_model=nn_model).to(device) + model = MultiModalMoCo(m=0.99, T=0.2, nn_model=nn_model).to(device) # Initialize task if args.dataset == 'pretrain': @@ -316,12 +316,12 @@ if __name__ == "__main__": if task == "pretrain": # Initialize optimizer - modules = list(model.vision_base_q.parameters()) + list(model.tactile_base_q.parameters()) + list(model.phi_vision_q.parameters()) + list(model.phi_tactile_q.parameters()) - optimizer = optim.Adam(modules, lr=0.03) + modules = list(model.vision_base_q.parameters()) + list(model.tactile_base_q.parameters()) + list(model.phi_vision_q.parameters()) + list(model.phi_tactile_q.parameters()) + \ + list(model.Phi_vision_q.parameters()) + list(model.Phi_tactile_q.parameters()) classifier, criterion = None, None else: # Initialize training - #model.load_state_dict(torch.load('/home/cpsadmin/TAG/models/model_tag_238.pth')) + #model.load_state_dict(torch.load('/media/vedant/cpsDataStorageWK/Vedant/Tactile/TAG/models/model_tag_238.pth')) model.load_state_dict(torch.load('/media/vedant/cpsDataStorageWK/Vedant/Tactile/TAG/models/model_tag_pretrain_220.pth')['model']) classifier = LinearClassifierResNet(layer=args.num_layers, n_label=n_labels) optimizer = optim.Adam(classifier.parameters(), lr=1e-4)