diff --git a/README.md b/README.md new file mode 100644 index 0000000..373029b --- /dev/null +++ b/README.md @@ -0,0 +1,23 @@ +This is the coding part of the lecture Deep Representation Learning in PyTorch. + +python 3.10 + +#### Autoencoder +In this demo we will implement a simple autoencoder. The autoencoder will be trained on the MNIST dataset. The autoencoder will be implemented in the file autoencoder.py. The file autoencoder.py contains a class Autoencoder on the MNIST dataset. +We compare the performance of the fully connected autoencoder with a convolutional autoencoder. +Jupyter notebooks: +* autoencoder_mnist.ipynb +* autoencoder_cifar10.ipynb + +#### Contrastive Learning (SimCLR) +In this demo we implemented the SimCLR [1] algorithm and trained it on the CIFAR10 dataset. + +Download the pretrained encoder [here](https://cloud.cps.unileoben.ac.at/index.php/s/feHYqRHwDy7mMDm) and put it in the folder `runs`. + +The code was adapted from this [repo](https://github.com/sthalles/SimCLR/tree/master). + +Jupyter notebooks: +* simclr.ipynb + + +[1] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. 2020. A simple framework for contrastive learning of visual representations. In Proceedings of the 37th International Conference on Machine Learning (ICML'20), Vol. 119. JMLR.org, Article 149, 1597–1607. 
\ No newline at end of file diff --git a/autoencoder_cifar10.ipynb b/autoencoder_cifar10.ipynb new file mode 100644 index 0000000..5f61ef2 --- /dev/null +++ b/autoencoder_cifar10.ipynb @@ -0,0 +1,1622 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "| Credentials | |\n", + "|----|----------------------------------|\n", + "|Host | Montanuniversitaet Leoben |\n", + "|Web | https://cps.unileoben.ac.at |\n", + "|Mail | cps@unileoben.ac.at |\n", + "|Author | Fotios Lygerakis |\n", + "|Corresponding Authors | fotios.lygerakis@unileoben.ac.at |\n", + "|Last edited | 28.09.2023 |" + ], + "metadata": { + "collapsed": false + }, + "id": "25b3e26209aea07e" + }, + { + "cell_type": "markdown", + "id": "1490260facaff836", + "metadata": { + "collapsed": false + }, + "source": [ + "In the first part of this tutorial we will build a fully connected MLP Autoencoder on the CIFAR10 dataset. Then we will perform linear probing on the encoder features to see how well they perform on a linear classification task." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "fc18830bb6f8d534", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:41:16.496631884Z", + "start_time": "2023-10-03T12:41:14.795432603Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import matplotlib.pyplot as plt\n", + "from torchvision import datasets, transforms\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "from sklearn.metrics import adjusted_rand_score\n", + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.cluster import KMeans\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "markdown", + "id": "1bad4bd03deb5b7e", + "metadata": { + "collapsed": false + }, + "source": [ + "Set random seed" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "27dd48e60ae7dd9e", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:41:16.504621133Z", + "start_time": "2023-10-03T12:41:16.497289235Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "" + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Set random seed\n", + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "markdown", + "id": "cc7f167a33227e94", + "metadata": { + "collapsed": false + }, + "source": [ + "Load the CIFAR10 dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "34248e8bc2678fd3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:41:18.687951160Z", + "start_time": "2023-10-03T12:41:17.491688103Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + } + ], + "source": [ + "# Define the transformations\n", + "transform = transforms.Compose([transforms.ToTensor(),\n", + " 
transforms.Normalize((0.5,), (0.5,))])\n", + "# Download and load the training data\n", + "trainset = datasets.CIFAR10('data', download=True, train=True, transform=transform)\n", + "# Download and load the test data\n", + "testset = datasets.CIFAR10('data', download=True, train=False, transform=transform)\n" + ] + }, + { + "cell_type": "markdown", + "id": "928dfac955d0d778", + "metadata": { + "collapsed": false + }, + "source": [ + "Print some examples from the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "87c6eae807f51118", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:41:19.826292939Z", + "start_time": "2023-10-03T12:41:19.492130184Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Get the first 10 samples from CIFAR10\n", + "dataiter = iter(trainset)\n", + "images, labels = [], []\n", + "for i in range(10):\n", + " image, label = next(dataiter)\n", + " images.append(image)\n", + " labels.append(label)\n", + "\n", + "# CIFAR10 label names\n", + "cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n", + "\n", + "# Plot the CIFAR10 samples\n", + "fig, axes = plt.subplots(2, 5, figsize=(10, 4))\n", + "for ax, img, lbl in zip(axes.ravel(), images, labels):\n", + " ax.imshow(img.permute(1, 2, 0).numpy() * 0.5 + 0.5) # denormalize\n", + " ax.set_title(f'Label: {cifar10_classes[lbl]}')\n", + " ax.axis('off')\n", + "plt.tight_layout()\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "id": "e4e25962ef8e5b0d", + "metadata": { + "collapsed": false + }, + "source": [ + "Define the MLP and Convolutional Autoencoder" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "26f2513d92b78e1e", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:41:21.511063666Z", + "start_time": "2023-10-03T12:41:21.485979560Z" + } + }, + "outputs": [], + "source": [ + "\n", + "class Autoencoder(nn.Module):\n", + " def __init__(self, input_size, hidden_size, type='mlp'):\n", + " super(Autoencoder, self).__init__()\n", + " self.type = type\n", + " if self.type == 'mlp':\n", + " self.encoder = nn.Sequential(\n", + " nn.Linear(input_size, hidden_size),\n", + " nn.ReLU(True))\n", + " self.decoder = nn.Sequential(\n", + " nn.Linear(hidden_size, input_size),\n", + " nn.ReLU(True),\n", + " nn.Sigmoid()\n", + " )\n", + " elif self.type == 'cnn':\n", + " # Encoder module for CIFAR10\n", + " self.encoder = nn.Sequential(\n", + " nn.Conv2d(input_size, 16, 3, stride=2, padding=1), # 16x16x16\n", + " nn.ReLU(True),\n", + " nn.Conv2d(16, 32, 3, stride=2, padding=1), # 
8x8x32\n", + " nn.ReLU(True),\n", + " nn.Conv2d(32, 64, 3, stride=2, padding=1), # 4x4x64\n", + " nn.ReLU(True)\n", + " )\n", + " # Decoder module for CIFAR10\n", + " self.decoder = nn.Sequential(\n", + " nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), # 8x8x32\n", + " nn.ReLU(True),\n", + " nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1), # 16x16x16\n", + " nn.ReLU(True),\n", + " nn.ConvTranspose2d(16, input_size, 3, stride=2, padding=1, output_padding=1), # 32x32x3\n", + " nn.Sigmoid()\n", + " )\n", + " else:\n", + " raise ValueError(f\"Unknown Autoencoder type: {type}\")\n", + " \n", + " def forward(self, x):\n", + " x = self.encoder(x)\n", + " x = self.decoder(x)\n", + " return x\n" + ] + }, + { + "cell_type": "markdown", + "id": "91a01313b4d95274", + "metadata": { + "collapsed": false + }, + "source": [ + "Check if GPU support is available" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "67006b35b75d8dff", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:41:22.884901160Z", + "start_time": "2023-10-03T12:41:22.860000300Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cuda\n" + ] + } + ], + "source": [ + "# device\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(device)" + ] + }, + { + "cell_type": "markdown", + "id": "8eebf70cb27640d5", + "metadata": { + "collapsed": false + }, + "source": [ + "Define the training function" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "5f96f7be13984747", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:41:23.954782535Z", + "start_time": "2023-10-03T12:41:23.943574306Z" + } + }, + "outputs": [], + "source": [ + "# Define the training function\n", + "def train(model, train_loader, optimizer, criterion, epoch, verbose=True):\n", + " model.train()\n", + " train_loss = 0\n", + " for i, 
(data, _) in enumerate(train_loader):\n", + " # check the type of autoencoder and modify the input data accordingly\n", + " if model.type == 'mlp':\n", + " data = data.view(data.size(0), -1)\n", + " data = data.to(device)\n", + " optimizer.zero_grad()\n", + " output = model(data)\n", + " loss = criterion(output, data)\n", + " loss.backward()\n", + " train_loss += loss.item()\n", + " optimizer.step()\n", + " train_loss /= len(train_loader.dataset)\n", + " if verbose:\n", + " print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}') \n", + " return train_loss" + ] + }, + { + "cell_type": "markdown", + "id": "5f6386edcab6b1e4", + "metadata": { + "collapsed": false + }, + "source": [ + "The evaluation functions for the linear classification" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b2c4483492fdd427", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:41:24.972220830Z", + "start_time": "2023-10-03T12:41:24.960696742Z" + } + }, + "outputs": [], + "source": [ + "# Extract encoded representations for a given loader\n", + "def extract_features(loader, model):\n", + " features = []\n", + " labels = []\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for data in loader:\n", + " img, label = data\n", + " if model.type == 'mlp':\n", + " img = img.view(img.size(0), -1)\n", + " img = img.to(device)\n", + " feature = model.encoder(img)\n", + " if model.type == 'cnn':\n", + " feature = feature.view(feature.size(0), -1) # Flatten the CNN encoded features\n", + " features.append(feature)\n", + " labels.append(label)\n", + " return torch.cat(features), torch.cat(labels)\n", + "\n", + "# Define the loss test function\n", + "def test_loss(model, test_loader, criterion, verbose=True):\n", + " model.eval()\n", + " eval_loss = 0\n", + " with torch.no_grad():\n", + " for i, (data, _) in enumerate(test_loader):\n", + " # check the type of autoencoder and modify the input data accordingly\n", + " if 
model.type == 'mlp':\n", + " data = data.view(data.size(0), -1)\n", + " data = data.to(device)\n", + " output = model(data)\n", + " eval_loss += criterion(output, data).item()\n", + " eval_loss /= len(test_loader.dataset)\n", + " if verbose:\n", + " print('====> Test set loss: {:.4f}'.format(eval_loss))\n", + " return eval_loss\n", + "\n", + "# Define the linear classification test function\n", + "def test_linear(encoded_train, train_labels, encoded_test, test_labels):\n", + " train_features_np = encoded_train.cpu().numpy()\n", + " train_labels_np = train_labels.cpu().numpy()\n", + " test_features_np = encoded_test.cpu().numpy()\n", + " test_labels_np = test_labels.cpu().numpy()\n", + " \n", + " # Apply logistic regression on train features and labels\n", + " logistic_regression = LogisticRegression(random_state=0, max_iter=100).fit(train_features_np, train_labels_np)\n", + " print(f\"Train accuracy: {logistic_regression.score(train_features_np, train_labels_np)}\")\n", + " # Apply logistic regression on test features and labels\n", + " test_accuracy = logistic_regression.score(test_features_np, test_labels_np)\n", + " print(f\"Test accuracy: {test_accuracy}\")\n", + " return test_accuracy\n", + "\n", + "\n", + "def test_clustering(encoded_features, true_labels):\n", + " encoded_features_np = encoded_features.cpu().numpy()\n", + " true_labels_np = true_labels.cpu().numpy()\n", + " \n", + " # Apply k-means clustering\n", + " kmeans = KMeans(n_clusters=10, n_init=10, random_state=0).fit(encoded_features_np)\n", + " cluster_labels = kmeans.labels_\n", + " \n", + " # Evaluate clustering results using Adjusted Rand Index\n", + " ari_score = adjusted_rand_score(true_labels_np, cluster_labels)\n", + " print(f\"Clustering ARI score: {ari_score}\")\n", + " return ari_score\n", + "\n", + "def knn_classifier(encoded_train, train_labels, encoded_test, test_labels, k=5):\n", + " encoded_train_np = encoded_train.cpu().numpy()\n", + " encoded_test_np = 
encoded_test.cpu().numpy()\n", + " train_labels_np = train_labels.cpu().numpy()\n", + " test_labels_np = test_labels.cpu().numpy()\n", + " \n", + " # Apply k-nearest neighbors classification\n", + " knn = KNeighborsClassifier(n_neighbors=k).fit(encoded_train_np, train_labels_np)\n", + " accuracy_score = knn.score(encoded_test_np, test_labels_np)\n", + " print(f\"KNN accuracy: {accuracy_score}\")\n", + " return accuracy_score\n", + "\n", + "def test(model, train_loader, test_loader, criterion):\n", + " # Extract features once for all tests\n", + " encoded_train, train_labels = extract_features(train_loader, model)\n", + " encoded_test, test_labels = extract_features(test_loader, model)\n", + " print(f\"{model.type} Autoencoder\")\n", + " results = {\n", + " 'reconstruction_loss': test_loss(model, test_loader, criterion),\n", + " 'linear_classification_accuracy': test_linear(encoded_train, train_labels, encoded_test, test_labels),\n", + " 'knn_classification_accuracy': knn_classifier(encoded_train, train_labels, encoded_test, test_labels),\n", + " 'clustering_ari_score': test_clustering(encoded_test, test_labels)\n", + " }\n", + " \n", + " # Save results to a log file\n", + " with open(\"evaluation_results.log\", \"w\") as log_file:\n", + " for key, value in results.items():\n", + " log_file.write(f\"{key}: {value}\")\n", + " \n", + " return results\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "bcb22bc5af9fb014", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:41:26.789225198Z", + "start_time": "2023-10-03T12:41:26.359259721Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MLP AE parameters: 789632\n", + "CNN AE parameters: 47107\n" + ] + } + ], + "source": [ + "# Define the training parameters for the fully connected MLP Autoencoder\n", + "batch_size = 32\n", + "epochs = 5\n", + "input_size = trainset.data.shape[1] * trainset.data.shape[2] * 
trainset.data.shape[3]\n", + "hidden_size = 128\n", + "train_frequency = epochs\n", + "test_frequency = epochs\n", + "\n", + "# Create the fully connected MLP Autoencoder\n", + "ae = Autoencoder(input_size, hidden_size, type='mlp').to(device)\n", + "input_size = trainset.data.shape[3]\n", + "cnn_ae = Autoencoder(input_size, hidden_size, type='cnn').to(device)\n", + "# print the models' number of parameters\n", + "print(f\"MLP AE parameters: {sum(p.numel() for p in ae.parameters())}\")\n", + "print(f\"CNN AE parameters: {sum(p.numel() for p in cnn_ae.parameters())}\")\n", + "\n", + "# Define the loss function and optimizer\n", + "criterion = nn.MSELoss()\n", + "optimizer = optim.Adam(ae.parameters(), lr=1e-3)\n", + "optimizer_cnn = optim.Adam(cnn_ae.parameters(), lr=1e-3)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f1626ce45bb25883", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:41:27.399274149Z", + "start_time": "2023-10-03T12:41:27.378668195Z" + } + }, + "outputs": [], + "source": [ + "# Create the train and test dataloaders\n", + "train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n", + "test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7472159fdc5f2532", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:47:02.996829577Z", + "start_time": "2023-10-03T12:41:27.969843962Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 4/5 [01:21<00:22, 22.85s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mlp====> Epoch: 5 Average loss: 0.0174\n", + "cnn====> Epoch: 5 Average loss: 0.0045\n", + "mlp Autoencoder\n", + "====> Test set loss: 0.0172\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy: 0.39372\n", + "Test accuracy: 0.3933\n", + "KNN accuracy: 0.3391\n", + "Clustering ARI score: 0.03965075119494781\n", + "cnn Autoencoder\n", + "====> Test set loss: 0.0044\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. 
of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy: 0.41372\n", + "Test accuracy: 0.3981\n", + "KNN accuracy: 0.3545\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 5/5 [05:34<00:00, 67.00s/it] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clustering ARI score: 0.051552759482472406\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "test_mlp = []\n", + "test_cnn = []\n", + "# Train the model\n", + "for epoch in tqdm(range(1, epochs + 1)):\n", + " verbose = True if epoch % train_frequency == 0 else False\n", + " train(ae, train_loader, optimizer, criterion, epoch, verbose)\n", + " train(cnn_ae, train_loader, optimizer_cnn, criterion, epoch, verbose)\n", + "\n", + " # test every n epochs\n", + " if epoch % test_frequency == 0:\n", + " restults_dic = test(ae, train_loader, test_loader, criterion)\n", + " test_mlp.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])\n", + " restults_dic = test(cnn_ae, train_loader, test_loader, criterion)\n", + " test_cnn.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])" + ] + }, + { + "cell_type": "markdown", + "id": "10639256e342a159", + "metadata": { + "collapsed": false + }, + "source": [ + "Compare the evaluation results 
of the MLP and CNN Autoencoders" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "50bb4c3c58af09ee", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T13:06:01.418631529Z", + "start_time": "2023-10-03T13:06:01.405317147Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model Reconstruction Loss Linear Accuracy KNN Accuracy Clustering ARI \n", + "MLP AE 0.0172 0.3933 0.3391 0.0397 \n", + "CNN AE 0.0044 0.3981 0.3545 0.0516 \n" + ] + } + ], + "source": [ + "print(f\"{'Model':<10} {'Reconstruction Loss':<20} {'Linear Accuracy':<20} {'KNN Accuracy':<20} {'Clustering ARI':<20}\")\n", + "print(f\"{'MLP AE':<10} {test_mlp[-1][0]:<20.4f} {test_mlp[-1][1]:<20.4f} {test_mlp[-1][2]:<20.4f} {test_mlp[-1][3]:<20.4f}\")\n", + "print(f\"{'CNN AE':<10} {test_cnn[-1][0]:<20.4f} {test_cnn[-1][1]:<20.4f} {test_cnn[-1][2]:<20.4f} {test_cnn[-1][3]:<20.4f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "b9201d1403781706", + "metadata": { + "collapsed": false + }, + "source": [ + "Develop a non-linear classifier with fully connected layers" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "1612800950703181", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:47:03.004530043Z", + "start_time": "2023-10-03T12:47:02.996038048Z" + } + }, + "outputs": [], + "source": [ + "# Define the fully connected classifier (default input_size=784 matches MNIST; pass the CIFAR10 input size explicitly)\n", + "class DenseClassifier(nn.Module):\n", + " def __init__(self, input_size=784, hidden_size=500, num_classes=10):\n", + " super(DenseClassifier, self).__init__()\n", + " self.type = 'mlp'\n", + " self.encoder = nn.Sequential(\n", + " nn.Linear(input_size, hidden_size),\n", + " nn.ReLU(True))\n", + " self.fc1 = nn.Linear(hidden_size, num_classes)\n", + "\n", + " def forward(self, x):\n", + " x = x.view(x.size(0), -1) # Flatten the input tensor\n", + " x = self.encoder(x)\n", + " x = self.fc1(x)\n", + " return x\n" + ] + }, 
+ { + "cell_type": "markdown", + "id": "35db4190e9c7f716", + "metadata": { + "collapsed": false + }, + "source": [ + "Develop a non-linear classifier with convolutional layers" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "cb2dfcf75113fd0b", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:47:03.041289873Z", + "start_time": "2023-10-03T12:47:02.999128924Z" + } + }, + "outputs": [], + "source": [ + "# cnn classifier\n", + "class CNNClassifier(nn.Module):\n", + " def __init__(self, input_size, hidden_size=128, num_classes=10):\n", + " super(CNNClassifier, self).__init__()\n", + " self.type = 'cnn'\n", + " # Encoder (Feature extractor)\n", + " self.encoder = nn.Sequential(\n", + " nn.Conv2d(input_size, 16, 3, stride=2, padding=1), # 16x16x16\n", + " nn.ReLU(True),\n", + " nn.Conv2d(16, 32, 3, stride=2, padding=1), # 8x8x32\n", + " nn.ReLU(True),\n", + " nn.Conv2d(32, 64, 3, stride=2, padding=1), # 4x4x64\n", + " nn.ReLU(True)\n", + " )\n", + " \n", + " # Classifier\n", + " self.classifier = nn.Sequential(\n", + " nn.Flatten(), \n", + " nn.Linear(4*4*64, hidden_size),\n", + " nn.ReLU(True),\n", + " nn.Linear(hidden_size, num_classes)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.encoder(x)\n", + " x = self.classifier(x)\n", + " return x\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "id": "7c3cf2371479da0c", + "metadata": { + "collapsed": false + }, + "source": [ + "Train and test functions for the non-linear classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "ac980d25bd8a3dd3", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:47:03.062646362Z", + "start_time": "2023-10-03T12:47:03.020366991Z" + } + }, + "outputs": [], + "source": [ + "# Train for the classifier\n", + "def train_classifier(model, train_loader, optimizer, criterion, epoch, verbose=True):\n", + " model.train()\n", + " train_loss = 0\n", + " 
correct = 0 \n", + " for i, (data, target) in enumerate(train_loader):\n", + " if model.type == 'cnn':\n", + " data = data.to(device)\n", + " else:\n", + " data = data.view(data.size(0), -1)\n", + " data = data.to(device)\n", + " target = target.to(device)\n", + " optimizer.zero_grad()\n", + " output = model(data)\n", + " loss = criterion(output, target)\n", + " loss.backward()\n", + " train_loss += loss.item()\n", + " optimizer.step()\n", + " # Calculate correct predictions for training accuracy\n", + " pred = output.argmax(dim=1, keepdim=True)\n", + " correct += pred.eq(target.view_as(pred)).sum().item()\n", + "\n", + " train_loss /= len(train_loader.dataset)\n", + " train_accuracy = 100. * correct / len(train_loader.dataset)\n", + " if verbose:\n", + " print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}')\n", + " print(f'{model.type}====> Epoch: {epoch} Training accuracy: {train_accuracy:.2f}%')\n", + " return train_loss\n", + "\n", + "\n", + "def test_classifier(model, test_loader, criterion):\n", + " model.eval()\n", + " eval_loss = 0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for i, (data, target) in enumerate(test_loader):\n", + " if model.type == 'cnn':\n", + " data = data.to(device)\n", + " else:\n", + " data = data.view(data.size(0), -1)\n", + " data = data.to(device)\n", + " target = target.to(device)\n", + " output = model(data)\n", + " eval_loss += criterion(output, target).item()\n", + " pred = output.argmax(dim=1, keepdim=True)\n", + " correct += pred.eq(target.view_as(pred)).sum().item()\n", + " eval_loss /= len(test_loader.dataset)\n", + " print('====> Test set loss: {:.4f}'.format(eval_loss))\n", + " accuracy = correct / len(test_loader.dataset)\n", + " print('====> Test set accuracy: {:.4f}'.format(accuracy))\n", + " return accuracy\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "dff05e622dcfd774", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": 
"2023-10-03T12:47:03.092600803Z", + "start_time": "2023-10-03T12:47:03.049595871Z" + } + }, + "outputs": [], + "source": [ + "# Define the training parameters for the fully connected classifier\n", + "batch_size = 32\n", + "epochs = 5\n", + "learning_rate = 1e-3\n", + "hidden_size = 128\n", + "num_classes = 10\n", + "train_frequency = epochs\n", + "test_frequency = epochs\n", + "# Create the fully connected classifier\n", + "input_size = trainset.data.shape[1] * trainset.data.shape[2] * trainset.data.shape[3]\n", + "classifier = DenseClassifier(input_size, hidden_size, num_classes).to(device)\n", + "input_size = trainset.data.shape[3]\n", + "cnn_classifier = CNNClassifier(input_size, hidden_size, num_classes).to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "3104345cdee0eb00", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:47:03.095971390Z", + "start_time": "2023-10-03T12:47:03.080480468Z" + } + }, + "outputs": [], + "source": [ + "# Define the loss function and optimizer\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)\n", + "optimizer_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n" + ] + }, + { + "cell_type": "markdown", + "id": "e1fed39be2f04745", + "metadata": { + "collapsed": false + }, + "source": [ + "Train the non-linear classifiers" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "abc0c6ce338d40d9", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:49:41.536281525Z", + "start_time": "2023-10-03T12:47:03.094970008Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 4/5 [02:03<00:30, 30.65s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mlp====> Epoch: 5 Average loss: 0.0401\n", + "mlp====> Epoch: 5 Training accuracy: 55.44%\n", + "cnn====> Epoch: 5 Average loss: 
0.0265\n", + "cnn====> Epoch: 5 Training accuracy: 69.87%\n", + "====> Test set loss: 0.0446\n", + "====> Test set accuracy: 0.5129\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 5/5 [02:38<00:00, 31.69s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "====> Test set loss: 0.0319\n", + "====> Test set accuracy: 0.6478\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Train the classifier\n", + "for epoch in tqdm(range(1, epochs + 1)):\n", + " verbose = True if epoch % train_frequency == 0 else False\n", + " train_classifier(classifier, train_loader, optimizer, criterion, epoch, verbose)\n", + " train_classifier(cnn_classifier, train_loader, optimizer_cnn, criterion, epoch, verbose)\n", + "\n", + " # test every n epochs\n", + " if epoch % test_frequency == 0:\n", + " test_acc = test_classifier(classifier, test_loader, criterion)\n", + " test_acc_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n" + ] + }, + { + "cell_type": "markdown", + "id": "a06038f113d8434f", + "metadata": { + "collapsed": false + }, + "source": [ + "Load the encoder weights into the classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "6a91d8894b70ef7c", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:49:41.571816419Z", + "start_time": "2023-10-03T12:49:41.535695938Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "" + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# initialize the classifier with the encoder weights\n", + "classifier.encoder.load_state_dict(ae.encoder.state_dict())\n", + "cnn_classifier.encoder.load_state_dict(cnn_ae.encoder.state_dict())" + ] + }, + { + "cell_type": "markdown", + "id": "aafa4a9ba7208647", + "metadata": { + "collapsed": false + }, + "source": [ + "Transfer learning" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 20, + "id": "a60dd68f988a8249", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:52:09.051066459Z", + "start_time": "2023-10-03T12:49:41.543118970Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 4/5 [02:06<00:31, 31.63s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mlp====> Epoch: 5 Average loss: 0.0539\n", + "mlp====> Epoch: 5 Training accuracy: 39.10%\n", + "cnn====> Epoch: 5 Average loss: 0.0333\n", + "cnn====> Epoch: 5 Training accuracy: 61.80%\n", + "====> Test set loss: 0.0536\n", + "====> Test set accuracy: 0.3961\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 5/5 [02:27<00:00, 29.50s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "====> Test set loss: 0.0341\n", + "====> Test set accuracy: 0.6133\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# fine-tune the classifier\n", + "learning_rate = 1e-5\n", + "epoch = 20\n", + "optimizer_pretrained = optim.Adam(classifier.parameters(), lr=learning_rate)\n", + "optimizer_pretrained_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n", + "for epoch in tqdm(range(1, epochs + 1)):\n", + " verbose = True if epoch % train_frequency == 0 else False\n", + " train_classifier(classifier, train_loader, optimizer_pretrained, criterion, epoch, verbose)\n", + " train_classifier(cnn_classifier, train_loader, optimizer_cnn, criterion, epoch, verbose)\n", + "\n", + " # test every n epochs\n", + " if epoch % test_frequency == 0:\n", + " test_acc_pretrained = test_classifier(classifier, test_loader, criterion)\n", + " test_acc_pretrained_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n" + ] + }, + { + "cell_type": "markdown", + "id": "31577275b833707a", + "metadata": { + "collapsed": 
import torchvision.models as models


class ResNet18Autoencoder(nn.Module):
    """Autoencoder with a ResNet18 encoder and a transposed-convolution decoder.

    The encoder is a randomly initialised ResNet18 without its final fully
    connected layer; its pooled output is flattened to a 512-dimensional code.
    The decoder reshapes that code to (512, 1, 1) and doubles the spatial size
    five times, producing a 3-channel 32x32 image with values in (0, 1).
    """

    def __init__(self):
        super(ResNet18Autoencoder, self).__init__()
        # Tag read by the training/evaluation helpers to decide whether the
        # input has to be flattened.
        self.type = 'cnn'
        # Encoder: ResNet18 trained from scratch (no pre-trained weights),
        # without its final fc layer.
        self.resnet18 = models.resnet18(weights=None)
        self.encoder = nn.Sequential(*list(self.resnet18.children())[:-1], nn.Flatten())

        # Decoder: each ConvTranspose2d(kernel_size=4, stride=2, padding=1)
        # doubles the spatial resolution.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # 1x1 -> 2x2
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 2x2 -> 4x4
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),  # 16x16 -> 32x32
            nn.Sigmoid()  # to ensure pixel values are between 0 and 1
        )

    def forward(self, x):
        """Encode `x` to a 512-d vector and decode it back to an image."""
        x = self.encoder(x)
        # unflatten the output of the encoder to be fed into the decoder
        x = x.view(x.size(0), 512, 1, 1)
        x = self.decoder(x)
        return x
# Hyper-parameters for the ResNet18 autoencoder run.
batch_size, epochs, learning_rate = 128, 10, 1e-3
# Report and evaluate only once, after the final epoch.
train_frequency = test_frequency = epochs

# Model
resnet18_ae = ResNet18Autoencoder().to(device)
n_params = sum(p.numel() for p in resnet18_ae.parameters())
print(f"ResNet18 AE parameters: {n_params}")

# Reconstruction loss, optimizer, and a scheduler that lowers the learning
# rate when the monitored loss stops decreasing.
criterion = nn.MSELoss()
optimizer = optim.Adam(resnet18_ae.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, verbose=True)

# Data loaders (both shuffled, as before).
loader_options = dict(batch_size=batch_size, shuffle=True)
train_loader = torch.utils.data.DataLoader(trainset, **loader_options)
test_loader = torch.utils.data.DataLoader(testset, **loader_options)
# Train the ResNet18 autoencoder while tracking the held-out reconstruction loss.
test_resnet18 = []
loss = 0  # most recent test loss; stays 0 until the first evaluation
pbar = tqdm(range(1, epochs + 1))
for epoch in pbar:
    verbose = epoch % train_frequency == 0
    train_loss = train(resnet18_ae, train_loader, optimizer, criterion, epoch, verbose)

    # Show the latest losses in the progress bar.
    pbar.set_description(f"Epoch {epoch} - Train Loss: {train_loss:.4f} - Test Loss: {loss:.4f}")

    # Every second epoch: measure the test loss and let the scheduler react to it.
    if epoch % 2 == 0:
        loss = test_loss(resnet18_ae, test_loader, criterion, verbose=False)
        scheduler.step(loss)

    # Full evaluation only every `test_frequency` epochs.
    if epoch % test_frequency != 0:
        continue
    results_resnet_dic = test(resnet18_ae, train_loader, test_loader, criterion)
    test_resnet18.append([
        results_resnet_dic[metric]
        for metric in (
            'reconstruction_loss',
            'linear_classification_accuracy',
            'knn_classification_accuracy',
            'clustering_ari_score',
        )
    ])
class ResNet18Classifier(nn.Module):
    """ResNet18 with a frozen backbone and a trainable classification head.

    Args:
        num_classes: number of output classes of the final linear layer.
        pretrained: if True, start from the ImageNet weights shipped with
            torchvision; otherwise the backbone is randomly initialised.

    The stem convolution is replaced by a 3x3/stride-1 convolution for small
    (CIFAR-10 sized) images. Because that layer is newly initialised, it is
    kept trainable together with the final fully connected layer; every other
    backbone parameter is frozen.
    """

    def __init__(self, num_classes=10, pretrained=False):
        super(ResNet18Classifier, self).__init__()
        self.type = 'cnn'
        # Load the ResNet18 model. The boolean flag is translated to the
        # `weights` argument; passing `pretrained=` to torchvision is deprecated.
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet18 = models.resnet18(weights=weights)

        # Freeze the backbone before swapping layers: modules created below
        # are trainable by default and therefore stay trainable.
        for param in self.resnet18.parameters():
            param.requires_grad = False

        # Adjust the first convolutional layer for CIFAR-10 image size.
        # Freezing this freshly initialised layer would put the (pre-trained)
        # network behind a random projection that can never be learned.
        self.resnet18.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # Adjust the final fully connected layer for CIFAR-10 number of classes
        num_ftrs = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.resnet18(x)
# Hyper-parameters for the ResNet18 classifier runs.
batch_size, epochs, learning_rate = 128, 50, 1e-4
# Report and evaluate only once, after the final epoch.
train_frequency = test_frequency = epochs

# Two classifiers that differ only in how the backbone is initialised.
resnet18_classifier = ResNet18Classifier(num_classes=10, pretrained=False).to(device)
resnet18_classifier_pretrained = ResNet18Classifier(num_classes=10, pretrained=True).to(device)
n_classifier_params = sum(p.numel() for p in resnet18_classifier.parameters())
print(f"ResNet18 classifier parameters: {n_classifier_params}")

# Classification loss and one Adam optimizer per model.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet18_classifier.parameters(), lr=learning_rate)
optimizer_pretrained = optim.Adam(resnet18_classifier_pretrained.parameters(), lr=learning_rate)
# Train the randomly initialised and the ImageNet-initialised classifier.
test_resnet18_classifier = []
test_resnet18_classifier_pretrained = []
pbar = tqdm(range(1, epochs + 1))

for epoch in pbar:
    verbose = epoch % train_frequency == 0
    train_loss = train_classifier(resnet18_classifier, train_loader, optimizer, criterion, epoch, verbose)
    train_loss_pretrained = train_classifier(resnet18_classifier_pretrained, train_loader, optimizer_pretrained, criterion, epoch, verbose)

    # Update tqdm description with the training loss
    pbar.set_description(f"Epoch {epoch} - Train Loss: {train_loss:.4f} - Train Loss Pretrained: {train_loss_pretrained:.4f}")

    # test every n epochs
    if epoch % test_frequency == 0:
        results_resnet_acc = test_classifier(resnet18_classifier, test_loader, criterion)
        test_resnet18_classifier.append(results_resnet_acc)
        results_resnet_acc_pretrained = test_classifier(resnet18_classifier_pretrained, test_loader, criterion)
        test_resnet18_classifier_pretrained.append(results_resnet_acc_pretrained)

# Accuracy summary across all models.
import numpy as np

# The longest label ('ResNet18 Classifier Pretrained') has 30 characters; the
# name column must be wider than that or the accuracy column is pushed out of
# alignment.
name_width = 32
print(f"{'Model':<{name_width}} {'Accuracy':<20}")
print(f"{'MLP AE':<{name_width}} {test_mlp[-1][1]:<20.4f}")
print(f"{'CNN AE':<{name_width}} {test_cnn[-1][1]:<20.4f}")
print(f"{'ResNet18 AE':<{name_width}} {test_resnet18[-1][1]:<20.4f}")
# take the average of the test accuracies
print(f"{'ResNet18 Classifier':<{name_width}} {np.mean(test_resnet18_classifier):<20.4f}")
print(f"{'ResNet18 Classifier Pretrained':<{name_width}} {np.mean(test_resnet18_classifier_pretrained):<20.4f}")
+ "id": "36fdbf815e16107b" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/autoencoder_mnist.ipynb b/autoencoder_mnist.ipynb new file mode 100644 index 0000000..ed96f88 --- /dev/null +++ b/autoencoder_mnist.ipynb @@ -0,0 +1,1135 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "| Credentials | |\n", + "|----|----------------------------------|\n", + "|Host | Montanuniversitaet Leoben |\n", + "|Web | https://cps.unileoben.ac.at |\n", + "|Mail | cps@unileoben.ac.at |\n", + "|Author | Fotios Lygerakis |\n", + "|Corresponding Authors | fotios.lygerakis@unileoben.ac.at |\n", + "|Last edited | 28.09.2023 |" + ], + "metadata": { + "collapsed": false + }, + "id": "ae041e151c5c2222" + }, + { + "cell_type": "markdown", + "source": [ + "In the first part of this tutorial we will build a fully connected MLP Autoencoder on the MNIST dataset. Then we will perform linear probing on the encoder features to see how well they perform on a linear classification task." 
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import adjusted_rand_score
from sklearn.linear_model import LogisticRegression
from sklearn.cluster import KMeans
from tqdm import tqdm

# Set random seed
torch.manual_seed(0)

# Define the transformations.
# ToTensor scales pixels to [0, 1], which is exactly the range the
# autoencoders can produce (both decoders end in a Sigmoid). Normalising to
# [-1, 1] would make every negative target value unreachable for the decoder.
transform = transforms.Compose([transforms.ToTensor()])
# Download and load the training data
trainset = datasets.MNIST('data', download=True, train=True, transform=transform)
# Download and load the test data
testset = datasets.MNIST('data', download=True, train=False, transform=transform)
class Autoencoder(nn.Module):
    """Small autoencoder with either a fully connected or a convolutional body.

    Args:
        input_size: for ``type='mlp'`` the flattened input dimension; for
            ``type='cnn'`` the number of image channels.
        hidden_size: for ``type='mlp'`` the width of the code; for
            ``type='cnn'`` the number of channels of the encoded feature map.
        type: ``'mlp'`` or ``'cnn'``.

    Raises:
        ValueError: if ``type`` is neither ``'mlp'`` nor ``'cnn'``.

    Both decoders end in a Sigmoid, so reconstructions lie in (0, 1).
    """

    def __init__(self, input_size, hidden_size, type='mlp'):
        super(Autoencoder, self).__init__()
        # type of autoencoder; also read by the training/evaluation helpers
        # to decide whether inputs must be flattened
        self.type = type
        if self.type == 'mlp':
            self.encoder = nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.ReLU(True))
            # No ReLU in front of the Sigmoid: a non-negative pre-activation
            # would restrict the output to [0.5, 1) and make dark pixels
            # impossible to reconstruct.
            self.decoder = nn.Sequential(
                nn.Linear(hidden_size, input_size),
                nn.Sigmoid()
            )
        elif self.type == 'cnn':
            # Encoder module: two stride-2 convolutions, each halving the
            # spatial resolution
            self.encoder = nn.Sequential(
                nn.Conv2d(in_channels=input_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=hidden_size//2, out_channels=hidden_size, kernel_size=3, stride=2, padding=1),
                nn.ReLU()
            )
            # Decoder module: mirrors the encoder and reconstructs as many
            # channels as the input has
            self.decoder = nn.Sequential(
                nn.ConvTranspose2d(in_channels=hidden_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.ReLU(),
                nn.ConvTranspose2d(in_channels=hidden_size//2, out_channels=input_size, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.Sigmoid()  # Sigmoid to ensure the output is between 0 and 1
            )
        else:
            # Fail at construction time instead of building a model without
            # encoder/decoder that only breaks on the first forward pass.
            raise ValueError(f"Unknown autoencoder type: {type!r} (expected 'mlp' or 'cnn')")

    def forward(self, x):
        """Return the reconstruction of `x`."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x
model(data)\n", + " loss = criterion(output, data)\n", + " loss.backward()\n", + " train_loss += loss.item()\n", + " optimizer.step()\n", + " train_loss /= len(train_loader.dataset)\n", + " if verbose:\n", + " print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}') \n", + " return train_loss" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:40:31.675127463Z", + "start_time": "2023-10-03T12:40:31.655370005Z" + } + }, + "id": "5f96f7be13984747" + }, + { + "cell_type": "markdown", + "source": [ + "The evaluation functions for the linear classification and clustering tasks" + ], + "metadata": { + "collapsed": false + }, + "id": "5f6386edcab6b1e4" + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [], + "source": [ + "# Extract encoded representations for a given loader\n", + "def extract_features(loader, model):\n", + " features = []\n", + " labels = []\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for data in loader:\n", + " img, label = data\n", + " if model.type == 'mlp':\n", + " img = img.view(img.size(0), -1)\n", + " img = img.to(device)\n", + " feature = model.encoder(img)\n", + " if model.type == 'cnn':\n", + " feature = feature.view(feature.size(0), -1) # Flatten the CNN encoded features\n", + " features.append(feature)\n", + " labels.append(label)\n", + " return torch.cat(features), torch.cat(labels)\n", + "\n", + "# Define the loss test function\n", + "def test_loss(model, test_loader, criterion):\n", + " model.eval()\n", + " eval_loss = 0\n", + " with torch.no_grad():\n", + " for i, (data, _) in enumerate(test_loader):\n", + " # check the type of autoencoder and modify the input data accordingly\n", + " if model.type == 'mlp':\n", + " data = data.view(data.size(0), -1)\n", + " data = data.to(device)\n", + " output = model(data)\n", + " eval_loss += criterion(output, data).item()\n", + " eval_loss /= len(test_loader.dataset)\n", + " print('====> Test set loss: 
{:.4f}'.format(eval_loss))\n", + " return eval_loss\n", + "\n", + "# Define the linear classification test function\n", + "def test_linear(encoded_train, train_labels, encoded_test, test_labels):\n", + " train_features_np = encoded_train.cpu().numpy()\n", + " train_labels_np = train_labels.cpu().numpy()\n", + " test_features_np = encoded_test.cpu().numpy()\n", + " test_labels_np = test_labels.cpu().numpy()\n", + " \n", + " # Apply logistic regression on train features and labels\n", + " logistic_regression = LogisticRegression(random_state=0, max_iter=100).fit(train_features_np, train_labels_np)\n", + " print(f\"Train accuracy: {logistic_regression.score(train_features_np, train_labels_np)}\")\n", + " # Apply logistic regression on test features and labels\n", + " test_accuracy = logistic_regression.score(test_features_np, test_labels_np)\n", + " print(f\"Test accuracy: {test_accuracy}\")\n", + " return test_accuracy\n", + "\n", + "\n", + "def test_clustering(encoded_features, true_labels):\n", + " encoded_features_np = encoded_features.cpu().numpy()\n", + " true_labels_np = true_labels.cpu().numpy()\n", + " \n", + " # Apply k-means clustering\n", + " kmeans = KMeans(n_clusters=10, n_init=10, random_state=0).fit(encoded_features_np)\n", + " cluster_labels = kmeans.labels_\n", + " \n", + " # Evaluate clustering results using Adjusted Rand Index\n", + " ari_score = adjusted_rand_score(true_labels_np, cluster_labels)\n", + " print(f\"Clustering ARI score: {ari_score}\")\n", + " return ari_score\n", + "\n", + "def knn_classifier(encoded_train, train_labels, encoded_test, test_labels, k=5):\n", + " encoded_train_np = encoded_train.cpu().numpy()\n", + " encoded_test_np = encoded_test.cpu().numpy()\n", + " train_labels_np = train_labels.cpu().numpy()\n", + " test_labels_np = test_labels.cpu().numpy()\n", + " \n", + " # Apply k-nearest neighbors classification\n", + " knn = KNeighborsClassifier(n_neighbors=k).fit(encoded_train_np, train_labels_np)\n", + " accuracy_score = 
knn.score(encoded_test_np, test_labels_np)\n", + " print(f\"KNN accuracy: {accuracy_score}\")\n", + " return accuracy_score\n", + "\n", + "def test(model, train_loader, test_loader, criterion):\n", + " # Extract features once for all tests\n", + " encoded_train, train_labels = extract_features(train_loader, model)\n", + " encoded_test, test_labels = extract_features(test_loader, model)\n", + " print(f\"{model.type} Autoencoder\")\n", + " results = {\n", + " 'reconstruction_loss': test_loss(model, test_loader, criterion),\n", + " 'linear_classification_accuracy': test_linear(encoded_train, train_labels, encoded_test, test_labels),\n", + " 'knn_classification_accuracy': knn_classifier(encoded_train, train_labels, encoded_test, test_labels),\n", + " 'clustering_ari_score': test_clustering(encoded_test, test_labels)\n", + " }\n", + " \n", + " # Save results to a log file\n", + " with open(\"evaluation_results.log\", \"w\") as log_file:\n", + " for key, value in results.items():\n", + " log_file.write(f\"{key}: {value}\")\n", + " \n", + " return results\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:40:32.662180572Z", + "start_time": "2023-10-03T12:40:32.657785583Z" + } + }, + "id": "b2c4483492fdd427" + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MLP AE parameters: 201616\n", + "CNN AE parameters: 148865\n" + ] + } + ], + "source": [ + "# Define the training parameters for the fully connected MLP Autoencoder\n", + "batch_size = 32\n", + "epochs = 5\n", + "hidden_size = 128\n", + "train_frequency = epochs\n", + "test_frequency = epochs\n", + "\n", + "# Create the fully connected MLP Autoencoder\n", + "input_size = trainset.data.shape[1] * trainset.data.shape[2]\n", + "ae = Autoencoder(input_size, hidden_size, type='mlp').to(device)\n", + "input_size=1\n", + "cnn_ae = Autoencoder(input_size, hidden_size, type='cnn').to(device)\n", + "# 
# Hyper-parameters shared by the MLP and the CNN autoencoder.
batch_size, epochs, hidden_size = 32, 5, 128
# Report and evaluate only once, after the final epoch.
train_frequency = test_frequency = epochs

# MLP autoencoder: operates on the flattened image (height * width inputs).
_, img_height, img_width = trainset.data.shape
input_size = img_height * img_width
ae = Autoencoder(input_size, hidden_size, type='mlp').to(device)
# CNN autoencoder: `input_size` is the number of image channels instead.
input_size = 1
cnn_ae = Autoencoder(input_size, hidden_size, type='cnn').to(device)

# Model sizes
n_params_mlp = sum(p.numel() for p in ae.parameters())
n_params_cnn = sum(p.numel() for p in cnn_ae.parameters())
print(f"MLP AE parameters: {n_params_mlp}")
print(f"CNN AE parameters: {n_params_cnn}")

# Reconstruction loss and one Adam optimizer per model.
criterion = nn.MSELoss()
optimizer = optim.Adam(ae.parameters(), lr=1e-3)
optimizer_cnn = optim.Adam(cnn_ae.parameters(), lr=1e-3)

# Train and test dataloaders (both shuffled, as before).
loader_options = dict(batch_size=batch_size, shuffle=True)
train_loader = torch.utils.data.DataLoader(trainset, **loader_options)
test_loader = torch.utils.data.DataLoader(testset, **loader_options)
of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy: 0.25626666666666664\n", + "Test accuracy: 0.2649\n", + "KNN accuracy: 0.2295\n", + "Clustering ARI score: 0.0614873771495409\n", + "cnn Autoencoder\n", + "====> Test set loss: 0.0260\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train accuracy: 0.9316166666666666\n", + "Test accuracy: 0.9278\n", + "KNN accuracy: 0.9639\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 5/5 [06:56<00:00, 83.22s/it] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clustering ARI score: 0.3909294873624941\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "test_mlp = []\n", + "test_cnn = []\n", + "# Train the model\n", + "for epoch in tqdm(range(1, epochs + 1)):\n", + " 
verbose = True if epoch % train_frequency == 0 else False\n", + " train(ae, train_loader, optimizer, criterion, epoch, verbose)\n", + " train(cnn_ae, train_loader, optimizer_cnn, criterion, epoch, verbose)\n", + "\n", + " # test every n epochs\n", + " if epoch % test_frequency == 0:\n", + " restults_dic = test(ae, train_loader, test_loader, criterion)\n", + " test_mlp.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])\n", + " restults_dic = test(cnn_ae, train_loader, test_loader, criterion)\n", + " test_cnn.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:47:30.814501313Z", + "start_time": "2023-10-03T12:40:34.720164326Z" + } + }, + "id": "7472159fdc5f2532" + }, + { + "cell_type": "markdown", + "source": [ + "Compare the evaluation results of the MLP and CNN Autoencoders" + ], + "metadata": { + "collapsed": false + }, + "id": "10639256e342a159" + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model Reconstruction Loss Linear Accuracy KNN Accuracy Clustering ARI \n", + "MLP AE 0.0598 0.2649 0.2295 0.0615 \n", + "CNN AE 0.0260 0.9278 0.9639 0.3909 \n" + ] + } + ], + "source": [ + "print(f\"{'Model':<10} {'Reconstruction Loss':<20} {'Linear Accuracy':<20} {'KNN Accuracy':<20} {'Clustering ARI':<20}\")\n", + "print(f\"{'MLP AE':<10} {test_mlp[-1][0]:<20.4f} {test_mlp[-1][1]:<20.4f} {test_mlp[-1][2]:<20.4f} {test_mlp[-1][3]:<20.4f}\")\n", + "print(f\"{'CNN AE':<10} {test_cnn[-1][0]:<20.4f} {test_cnn[-1][1]:<20.4f} {test_cnn[-1][2]:<20.4f} {test_cnn[-1][3]:<20.4f}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": 
"2023-10-03T12:47:30.828062767Z", + "start_time": "2023-10-03T12:47:30.812448850Z" + } + }, + "id": "50bb4c3c58af09ee" + }, + { + "cell_type": "markdown", + "source": [ + "Develop a non-linear classifier with fully connected layers (an MLP with one hidden ReLU layer)" + ], + "metadata": { + "collapsed": false + }, + "id": "b9201d1403781706" + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [], + "source": [ + "# Define the fully connected classifier for MNIST\n", + "class DenseClassifier(nn.Module):\n", + " def __init__(self, input_size=784, hidden_size=500, num_classes=10):\n", + " super(DenseClassifier, self).__init__()\n", + " self.type = 'mlp'\n", + " self.encoder = nn.Sequential(\n", + " nn.Linear(input_size, hidden_size),\n", + " nn.ReLU(True))\n", + " self.fc1 = nn.Linear(hidden_size, num_classes)\n", + "\n", + " def forward(self, x):\n", + " x = x.view(x.size(0), -1) # Flatten the input tensor\n", + " x = self.encoder(x)\n", + " x = self.fc1(x)\n", + " return x\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:47:30.833296890Z", + "start_time": "2023-10-03T12:47:30.819270525Z" + } + }, + "id": "1612800950703181" + }, + { + "cell_type": "markdown", + "source": [ + "Develop a non-linear classifier with convolutional layers" + ], + "metadata": { + "collapsed": false + }, + "id": "35db4190e9c7f716" + }, + { + "cell_type": "code", + "execution_count": 14, + "outputs": [], + "source": [ + "# cnn classifier\n", + "class CNNClassifier(nn.Module):\n", + " def __init__(self, input_size=3, hidden_size=128, num_classes=10):\n", + " super(CNNClassifier, self).__init__()\n", + " self.type = 'cnn'\n", + " # Encoder (Feature extractor)\n", + " self.encoder = nn.Sequential(\n", + " nn.Conv2d(in_channels=input_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU(),\n", + " nn.Conv2d(in_channels=hidden_size//2, out_channels=hidden_size, kernel_size=3, stride=2, padding=1),\n", + " nn.ReLU()\n", + " )\n", + " \n", + " 
# Classifier\n", + " # Here, for the sake of example, I'm assuming the spatial size of the encoder output \n", + " # is 7x7 for an input size of 28x28. You might want to adjust this if the spatial dimensions change.\n", + " self.classifier = nn.Sequential(\n", + " nn.Flatten(),\n", + " nn.Linear(hidden_size*7*7, hidden_size),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_size, num_classes),\n", + " nn.LogSoftmax(dim=1) # LogSoftmax is typically used with NLLLoss\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.encoder(x)\n", + " x = self.classifier(x)\n", + " return x\n", + " return x" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:47:30.916320521Z", + "start_time": "2023-10-03T12:47:30.828281294Z" + } + }, + "id": "cb2dfcf75113fd0b" + }, + { + "cell_type": "markdown", + "source": [ + "Train and test functions for the non-linear classifier" + ], + "metadata": { + "collapsed": false + }, + "id": "7c3cf2371479da0c" + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [], + "source": [ + "# Train for the classifier\n", + "def train_classifier(model, train_loader, optimizer, criterion, epoch, verbose=True):\n", + " model.train()\n", + " train_loss = 0\n", + " correct = 0 \n", + " for i, (data, target) in enumerate(train_loader):\n", + " if model.type == 'cnn':\n", + " data = data.to(device)\n", + " else:\n", + " data = data.view(data.size(0), -1)\n", + " data = data.to(device)\n", + " target = target.to(device)\n", + " optimizer.zero_grad()\n", + " output = model(data)\n", + " loss = criterion(output, target)\n", + " loss.backward()\n", + " train_loss += loss.item()\n", + " optimizer.step()\n", + " # Calculate correct predictions for training accuracy\n", + " pred = output.argmax(dim=1, keepdim=True)\n", + " correct += pred.eq(target.view_as(pred)).sum().item()\n", + "\n", + " train_loss /= len(train_loader.dataset)\n", + " train_accuracy = 100. 
* correct / len(train_loader.dataset)\n", + " if verbose:\n", + " print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}')\n", + " print(f'{model.type}====> Epoch: {epoch} Training accuracy: {train_accuracy:.2f}%')\n", + " return train_loss\n", + "\n", + "\n", + "def test_classifier(model, test_loader, criterion):\n", + " model.eval()\n", + " eval_loss = 0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for i, (data, target) in enumerate(test_loader):\n", + " if model.type == 'cnn':\n", + " data = data.to(device)\n", + " else:\n", + " data = data.view(data.size(0), -1)\n", + " data = data.to(device)\n", + " target = target.to(device)\n", + " output = model(data)\n", + " eval_loss += criterion(output, target).item()\n", + " pred = output.argmax(dim=1, keepdim=True)\n", + " correct += pred.eq(target.view_as(pred)).sum().item()\n", + " eval_loss /= len(test_loader.dataset)\n", + " print('====> Test set loss: {:.4f}'.format(eval_loss))\n", + " accuracy = correct / len(test_loader.dataset)\n", + " print('====> Test set accuracy: {:.4f}'.format(accuracy))\n", + " return accuracy\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:47:30.924063759Z", + "start_time": "2023-10-03T12:47:30.875486950Z" + } + }, + "id": "ac980d25bd8a3dd3" + }, + { + "cell_type": "code", + "execution_count": 16, + "outputs": [], + "source": [ + "# Define the training parameters for the fully connected classifier\n", + "batch_size = 32\n", + "epochs = 5\n", + "learning_rate = 1e-3\n", + "hidden_size = 128\n", + "num_classes = 10\n", + "train_frequency = epochs\n", + "test_frequency = epochs\n", + "# Create the fully connected classifier\n", + "input_size = trainset.data.shape[1] * trainset.data.shape[2]\n", + "classifier = DenseClassifier(input_size, hidden_size, num_classes).to(device)\n", + "input_size = 1\n", + "cnn_classifier = CNNClassifier(input_size, hidden_size, num_classes).to(device)" + ], + "metadata": { + 
"collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:47:30.924544974Z", + "start_time": "2023-10-03T12:47:30.875705556Z" + } + }, + "id": "dff05e622dcfd774" + }, + { + "cell_type": "code", + "execution_count": 17, + "outputs": [], + "source": [ + "# Define the loss function and optimizer\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)\n", + "optimizer_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:47:30.924773429Z", + "start_time": "2023-10-03T12:47:30.875787620Z" + } + }, + "id": "3104345cdee0eb00" + }, + { + "cell_type": "markdown", + "source": [ + "Train the non-linear classifiers" + ], + "metadata": { + "collapsed": false + }, + "id": "e1fed39be2f04745" + }, + { + "cell_type": "code", + "execution_count": 18, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 4/5 [02:10<00:32, 32.47s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mlp====> Epoch: 5 Average loss: 0.0030\n", + "mlp====> Epoch: 5 Training accuracy: 97.08%\n", + "cnn====> Epoch: 5 Average loss: 0.0005\n", + "cnn====> Epoch: 5 Training accuracy: 99.49%\n", + "====> Test set loss: 0.0035\n", + "====> Test set accuracy: 0.9686\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 5/5 [02:48<00:00, 33.62s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "====> Test set loss: 0.0013\n", + "====> Test set accuracy: 0.9880\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Train the classifier\n", + "for epoch in tqdm(range(1, epochs + 1)):\n", + " verbose = True if epoch % train_frequency == 0 else False\n", + " train_classifier(classifier, train_loader, optimizer, criterion, epoch, verbose)\n", + " 
train_classifier(cnn_classifier, train_loader, optimizer_cnn, criterion, epoch, verbose)\n", + "\n", + " # test every n epochs\n", + " if epoch % test_frequency == 0:\n", + " test_acc = test_classifier(classifier, test_loader, criterion)\n", + " test_acc_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:50:19.005163249Z", + "start_time": "2023-10-03T12:47:30.875867176Z" + } + }, + "id": "abc0c6ce338d40d9" + }, + { + "cell_type": "markdown", + "source": [ + "Load the encoder weights into the classifier" + ], + "metadata": { + "collapsed": false + }, + "id": "a06038f113d8434f" + }, + { + "cell_type": "code", + "execution_count": 19, + "outputs": [ + { + "data": { + "text/plain": "" + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# initialize the classifier with the encoder weights\n", + "classifier.encoder.load_state_dict(ae.encoder.state_dict())\n", + "cnn_classifier.encoder.load_state_dict(cnn_ae.encoder.state_dict())" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:50:19.005816667Z", + "start_time": "2023-10-03T12:50:18.994175691Z" + } + }, + "id": "6a91d8894b70ef7c" + }, + { + "cell_type": "markdown", + "source": [ + "Transfer learning" + ], + "metadata": { + "collapsed": false + }, + "id": "aafa4a9ba7208647" + }, + { + "cell_type": "code", + "execution_count": 20, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 4/5 [01:56<00:27, 27.00s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mlp====> Epoch: 5 Average loss: 0.0547\n", + "mlp====> Epoch: 5 Training accuracy: 38.00%\n", + "cnn====> Epoch: 5 Average loss: 0.0004\n", + "cnn====> Epoch: 5 Training accuracy: 99.57%\n", + "====> Test set loss: 0.0526\n", + "====> Test set accuracy: 0.4150\n" + ] + }, + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "100%|██████████| 5/5 [02:25<00:00, 29.01s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "====> Test set loss: 0.0017\n", + "====> Test set accuracy: 0.9868\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# fine-tune the classifier\n", + "learning_rate = 1e-5\n", + "epochs = 5\n", + "train_frequency = epochs\n", + "test_frequency = epochs\n", + "optimizer_pretrained = optim.Adam(classifier.parameters(), lr=learning_rate)\n", + "optimizer_pretrained_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n", + "for epoch in tqdm(range(1, epochs + 1)):\n", + " verbose = True if epoch % train_frequency == 0 else False\n", + " train_loss = train_classifier(classifier, train_loader, optimizer_pretrained, criterion, epoch, verbose)\n", + " train_loss_cnn = train_classifier(cnn_classifier, train_loader, optimizer_pretrained_cnn, criterion, epoch, verbose)\n", + "\n", + " # test every n epochs\n", + " if epoch % test_frequency == 0:\n", + " test_acc_pretrained = test_classifier(classifier, test_loader, criterion)\n", + " test_acc_pretrained_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:52:44.085979647Z", + "start_time": "2023-10-03T12:50:19.003728990Z" + } + }, + "id": "a60dd68f988a8249" + }, + { + "cell_type": "markdown", + "source": [ + "Compare the results of the linear probing with the results of the non-linear classifiers trained from scratch and the fine-tuned (pretrained-encoder) classifiers" + ], + "metadata": { + "collapsed": false + }, + "id": "31577275b833707a" + }, + { + "cell_type": "code", + "execution_count": 21, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model Linear Accuracy Non-linear accuracy Pretrained accuracy \n", + "MLP AE 0.2649 0.9686 0.4150 \n", + "CNN AE 0.9278 0.9880 0.9868 \n" + ] + } + ], + "source": [ + "# print a table of the accuracies. 
compare the results with the results of the linear probing\n", + "print(f\"{'Model':<10} {'Linear Accuracy':<20} {'Non-linear accuracy':<20} {'Pretrained accuracy':<20}\")\n", + "print(f\"{'MLP AE':<10} {test_mlp[-1][1]:<20.4f} {test_acc:<20.4f} {test_acc_pretrained:<20.4f}\")\n", + "print(f\"{'CNN AE':<10} {test_cnn[-1][1]:<20.4f} {test_acc_cnn:<20.4f} {test_acc_pretrained_cnn:<20.4f}\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-10-03T12:52:44.091206610Z", + "start_time": "2023-10-03T12:52:44.084572508Z" + } + }, + "id": "40d0e7f3f13404c9" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "f38a1ab6951a694e" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/conv_animation.ipynb b/conv_animation.ipynb new file mode 100644 index 0000000..abde85c --- /dev/null +++ b/conv_animation.ipynb @@ -0,0 +1,195 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "| Credentials | |\n", + "|----|----------------------------------|\n", + "|Host | Montanuniversitaet Leoben |\n", + "|Web | https://cps.unileoben.ac.at |\n", + "|Mail | cps@unileoben.ac.at |\n", + "|Author | Fotios Lygerakis |\n", + "|Corresponding Authors | fotios.lygerakis@unileoben.ac.at |\n", + "|Last edited | 28.09.2023 |" + ], + "metadata": { + "collapsed": false + }, + "id": "e94a99fbf273dd6e" + }, + { + "cell_type": "markdown", + "source": [ + "This notebook contains code for visualizing the convolution operation." 
+ ], + "metadata": { + "collapsed": false + }, + "id": "ffa147d97adb2a8" + }, + { + "cell_type": "code", + "execution_count": 19, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.patches as patches\n", + "from matplotlib.animation import FuncAnimation, PillowWriter\n", + "\n", + "# Sample image (5x5)\n", + "image = np.array([\n", + " [1, 2, 3, 4, 5],\n", + " [5, 4, 3, 2, 1],\n", + " [1, 2, 3, 4, 5],\n", + " [5, 4, 3, 2, 1],\n", + " [1, 2, 3, 4, 5]\n", + "])\n", + "\n", + "# Kernel (3x3)\n", + "kernel = np.array([\n", + " [1, 0, -1],\n", + " [1, 0, -1],\n", + " [1, 0, -1]\n", + "])\n", + "# # kernel (2x2)\n", + "# kernel = np.array([\n", + "# [1, 0],\n", + "# [1, 0]\n", + "# ])\n", + "\n", + "# Stride and Padding\n", + "stride = 1\n", + "padding = 1\n", + "\n", + "# Create padded image\n", + "padded_image = np.pad(image, ((padding, padding), (padding, padding)))\n", + "\n", + "# Calculate output dimensions\n", + "output_dim = ((image.shape[0] - kernel.shape[0] + 2*padding) // stride) + 1\n", + "\n", + "# Initialize the convolution output\n", + "conv_output = np.zeros((output_dim, output_dim))\n", + "\n", + "# Calculate all possible top-left positions for the kernel\n", + "positions = [(y, x) for y in range(0, padded_image.shape[0] - kernel.shape[0] + 1, stride)\n", + " for x in range(0, padded_image.shape[1] - kernel.shape[1] + 1, stride)]\n", + "\n", + "# Set up the plotting with three subplots for the image, kernel, and convolution output\n", + "fig, (ax1, ax_kernel, ax2) = plt.subplots(1, 3, figsize=(15, 5))\n", + "ax1.imshow(padded_image, cmap='viridis', aspect='equal') # Use equal aspect for square pixels\n", + "ax_kernel.imshow(kernel, cmap='viridis', aspect='equal') # Display the kernel\n", + "ax2.imshow(conv_output, cmap='viridis', aspect='equal', vmin=-15, vmax=15) # Use equal aspect and set vmin, vmax for consistent color scaling\n", + "\n", + "# Add subplot titles\n", + "ax1.set_title(\"Input Image with Kernel Overlay\")\n", + 
"ax_kernel.set_title(\"Kernel\")\n", + "ax2.set_title(\"Convolution Output\")\n", + "\n", + "# Display numbers on the matrix for the padded image\n", + "for i in range(padded_image.shape[0]):\n", + " for j in range(padded_image.shape[1]):\n", + " ax1.text(j, i, str(padded_image[i, j]), ha='center', va='center', color='red')\n", + "\n", + "# Display numbers on the kernel\n", + "for i in range(kernel.shape[0]):\n", + " for j in range(kernel.shape[1]):\n", + " ax_kernel.text(j, i, str(kernel[i, j]), ha='center', va='center', color='red')\n", + "\n", + "# Kernel rectangle overlay on ax1\n", + "rect = patches.Rectangle((-0.5, -0.5), kernel.shape[1], kernel.shape[0], \n", + " linewidth=3, edgecolor='blue', facecolor='none')\n", + "ax1.add_patch(rect)\n", + "\n", + "# Animation function\n", + "def animate(i):\n", + " y, x = positions[i]\n", + " rect.set_xy((x-0.5, y-0.5))\n", + "\n", + " # Compute the convolution for the current position\n", + " region = padded_image[y:y+kernel.shape[0], x:x+kernel.shape[1]]\n", + " conv_value = np.sum(region * kernel)\n", + " \n", + " # Correctly compute the position in the output matrix\n", + " out_y = y // stride\n", + " out_x = x // stride\n", + " conv_output[out_y, out_x] = conv_value\n", + "\n", + " # Update the convolution output display\n", + " ax2.imshow(conv_output, cmap='viridis', aspect='equal', vmin=-15, vmax=15) # Adjusted colormap and vmin, vmax for better visualization\n", + "\n", + " # Display numbers on the matrix for the convolution output\n", + " for i in range(conv_output.shape[0]):\n", + " for j in range(conv_output.shape[1]):\n", + " ax2.text(j, i, f\"{conv_output[i, j]:.1f}\", ha='center', va='center', color='red')\n", + "\n", + " return rect,\n", + "\n", + "\n", + "# Create animation with increased interval for slower movement\n", + "ani = FuncAnimation(fig, animate, frames=len(positions), interval=2000, blit=True, repeat_delay=10000) # interval is in milliseconds: 2000 ms between frames for slower movement\n", + "\n", + "# Title and layout\n", 
+ "# title_text = f\"Kernel Size: {kernel.shape[0]}x{kernel.shape[1]}, Stride: {stride}, Padding: {padding}\"\n", + "# fig.suptitle(title_text, fontsize=16)\n", + "\n", + "# Adjust subplot spacing and layout\n", + "plt.subplots_adjust(wspace=0.5)\n", + "plt.tight_layout()\n", + "\n", + "# Save and show\n", + "writer = PillowWriter(fps=2)\n", + "ani.save(f\"convolution_animation_with_output_k{kernel.shape[0]}_s{stride}_p{padding}.gif\", writer=writer)\n", + "plt.show()\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-09-26T11:48:00.679662904Z", + "start_time": "2023-09-26T11:47:47.667621921Z" + } + }, + "id": "7e43a9eac318f3c3" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "bce4c6c7bb659beb" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/evaluation_results.log b/evaluation_results.log new file mode 100644 index 0000000..7c62658 --- /dev/null +++ b/evaluation_results.log @@ -0,0 +1 @@ +reconstruction_loss: 0.001255071537196636linear_classification_accuracy: 0.3665knn_classification_accuracy: 0.3889clustering_ari_score: 0.02700985553219709 \ No newline at end of file