diff --git a/tac_ssl_tag.py b/tac_ssl_tag.py index 69159a3..712d51e 100755 --- a/tac_ssl_tag.py +++ b/tac_ssl_tag.py @@ -318,6 +318,7 @@ if __name__ == "__main__": # Initialize optimizer modules = list(model.vision_base_q.parameters()) + list(model.tactile_base_q.parameters()) + list(model.phi_vision_q.parameters()) + list(model.phi_tactile_q.parameters()) + \ list(model.Phi_vision_q.parameters()) + list(model.Phi_tactile_q.parameters()) + optimizer = optim.Adam(modules, lr=0.03) classifier, criterion = None, None else: # Initialize training