Update 'tac_ssl_tag.py'
This commit is contained in:
parent
29c80c6d92
commit
913e000336
@ -25,7 +25,7 @@ def parse_option():
|
||||
|
||||
parser.add_argument('--print_freq', type=int, default=10, help='print frequency')
|
||||
parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
|
||||
parser.add_argument('--batch_size', type=int, default=256, help='batch_size')
|
||||
parser.add_argument('--batch_size', type=int, default=128, help='batch_size')
|
||||
parser.add_argument('--num_workers', type=int, default=18, help='num of workers to use')
|
||||
parser.add_argument('--epochs', type=int, default=61, help='number of training epochs')
|
||||
parser.add_argument('--num_layers', type=int, default=5, help='number of layers in resnet')
|
||||
@ -302,7 +302,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Initialize model
|
||||
nn_model = 'resnet18'
|
||||
model = MultiModalMoCo(m=0.99, T=0.07, nn_model=nn_model).to(device)
|
||||
model = MultiModalMoCo(m=0.99, T=0.2, nn_model=nn_model).to(device)
|
||||
|
||||
# Initialize task
|
||||
if args.dataset == 'pretrain':
|
||||
@ -316,12 +316,12 @@ if __name__ == "__main__":
|
||||
|
||||
if task == "pretrain":
|
||||
# Initialize optimizer
|
||||
modules = list(model.vision_base_q.parameters()) + list(model.tactile_base_q.parameters()) + list(model.phi_vision_q.parameters()) + list(model.phi_tactile_q.parameters())
|
||||
optimizer = optim.Adam(modules, lr=0.03)
|
||||
modules = list(model.vision_base_q.parameters()) + list(model.tactile_base_q.parameters()) + list(model.phi_vision_q.parameters()) + list(model.phi_tactile_q.parameters()) + \
|
||||
list(model.Phi_vision_q.parameters()) + list(model.Phi_tactile_q.parameters())
|
||||
classifier, criterion = None, None
|
||||
else:
|
||||
# Initialize training
|
||||
#model.load_state_dict(torch.load('/home/cpsadmin/TAG/models/model_tag_238.pth'))
|
||||
#model.load_state_dict(torch.load('/media/vedant/cpsDataStorageWK/Vedant/Tactile/TAG/models/model_tag_238.pth'))
|
||||
model.load_state_dict(torch.load('/media/vedant/cpsDataStorageWK/Vedant/Tactile/TAG/models/model_tag_pretrain_220.pth')['model'])
|
||||
classifier = LinearClassifierResNet(layer=args.num_layers, n_label=n_labels)
|
||||
optimizer = optim.Adam(classifier.parameters(), lr=1e-4)
|
||||
|
Loading…
Reference in New Issue
Block a user