{ "cells": [ { "cell_type": "markdown", "source": [ "| Credentials | |\n", "|----|----------------------------------|\n", "|Host | Montanuniversitaet Leoben |\n", "|Web | https://cps.unileoben.ac.at |\n", "|Mail | cps@unileoben.ac.at |\n", "|Author | Fotios Lygerakis |\n", "|Corresponding Authors | fotios.lygerakis@unileoben.ac.at |\n", "|Last edited | 28.09.2023 |" ], "metadata": { "collapsed": false }, "id": "ae041e151c5c2222" }, { "cell_type": "markdown", "source": [ "In the first part of this tutorial we will build a fully connected MLP Autoencoder on the MNIST dataset. Then we will perform linear probing on the encoder features to see how well they perform on a linear classification task." ], "metadata": { "collapsed": false }, "id": "1490260facaff836" }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import matplotlib.pyplot as plt\n", "from torchvision import datasets, transforms\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.metrics import adjusted_rand_score\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.cluster import KMeans\n", "from tqdm import tqdm" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:40:27.590742473Z", "start_time": "2023-10-03T12:40:25.356175335Z" } }, "id": "fc18830bb6f8d534" }, { "cell_type": "markdown", "source": [ "Set random seed" ], "metadata": { "collapsed": false }, "id": "1bad4bd03deb5b7e" }, { "cell_type": "code", "execution_count": 2, "outputs": [ { "data": { "text/plain": "" }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Set random seed\n", "torch.manual_seed(0)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:40:27.592356147Z", "start_time": "2023-10-03T12:40:27.568457489Z" } }, "id": "27dd48e60ae7dd9e" }, { "cell_type": "markdown", "source": [ "Load the MNIST dataset" ], "metadata": { "collapsed": false }, "id": "cc7f167a33227e94" }, { "cell_type": "code", "execution_count": 3, "outputs": [], "source": [ "# Define the transformations\n", "transform = transforms.Compose([transforms.ToTensor(),\n", " transforms.Normalize((0.5,), (0.5,))])\n", "# Download and load the training data\n", "trainset = datasets.MNIST('data', download=True, train=True, transform=transform)\n", "# Download and load the test data\n", "testset = datasets.MNIST('data', download=True, train=False, transform=transform)\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:40:27.639417871Z", "start_time": "2023-10-03T12:40:27.577605311Z" } }, "id": "34248e8bc2678fd3" }, { "cell_type": "markdown", "source": [ "Print some examples from the dataset" ], "metadata": { "collapsed": false }, "id": "928dfac955d0d778" }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "data": { "text/plain": "
", "image/png": "iVBORw0KGgoAAAANSUhEUgAAAnUAAAFiCAYAAACQzC7qAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAxUElEQVR4nO3dd3TUVfrH8ScUSegdFCWA9BUF6REBBQxNBKkqTUH9oZTlEEQswK7SpEgvigJZ2MPyAwLKYmElqLhsgFXYjRKMSIQgQqgBqTHf3x8e5pfnJkwyyUwmc+f9Oodz5jPfmfne5Cbx8TvP3BviOI4jAAAACGiF/D0AAAAA5B1FHQAAgAUo6gAAACxAUQcAAGABijoAAAALUNQBAABYgKIOAADAAhR1AAAAFqCoAwAAsIB1RV1SUpKEhITI7NmzvfaaO3fulJCQENm5c6fXXhPew5wHH+Y8+DDnwYc591yBKOpWrVolISEhsm/fPn8PxSemTJkiISEhmf6Fhob6e2h+Y/uci4gcP35c+vXrJ2XLlpXSpUvLY489Jj/++KO/h+U3wTDnGXXq1ElCQkJk5MiR/h6K39g+54cOHZKxY8dKRESEhIaGSkhIiCQlJfl7WH5l+5yLiKxbt07uv/9+CQ0NlUqVKsmwYcPk9OnT/h6WiIgU8fcAgsnSpUulZMmSrly4cGE/jga+dOnSJXnooYfkwoUL8sorr0jRokXl7bfflnbt2sn+/fulQoUK/h4ifGjTpk2ye/dufw8DPrZ7925ZsGCBNGzYUBo0aCD79+/395DgY0uXLpUXXnhBOnToIHPnzpXk5GSZP3++7Nu3T+Li4vx+sYaiLh/16dNHKlas6O9hIB8sWbJEEhMTZc+ePdK8eXMREenSpYvcc889MmfOHJk2bZqfRwhfuXr1qowbN04mTJggkyZN8vdw4EM9evSQ8+fPS6lSpWT27NkUdZa7fv26vPLKK9K2bVvZvn27hISEiIhIRESEPProo/Luu+/KqFGj/DrGAvH2a05cv35dJk2aJE2bNpUyZcpIiRIl5MEHH5TY2NhbPuftt9+W8PBwCQsLk3bt2kl8fHymxyQkJEifPn2kfPnyEhoaKs2aNZMPPvgg2/FcvnxZEhISPLrk6jiOpKamiuM4OX5OMAvkOd+wYYM0b97cVdCJiNSvX186dOgg69evz/b5wSqQ5/ymt956S9LT0yUqKirHzwlmgTzn5cuXl1KlSmX7OGiBOufx8fFy/vx56d+/v6ugExHp3r27lCxZUtatW5ftuXwtYIq61NRUWbFihbRv315mzpwpU6ZMkZSUFImMjMzy/46io6NlwYIF8uKLL8rEiRMlPj5eHn74YTl58qTrMd9++620atVKDh48KC+//LLMmTNHSpQoIT179pSYmBi349mzZ480aNBAFi1alOOvoVatWlKmTBkpVaqUDBw4UI0FmQXqnKenp8t//vMfadasWaZjLVq0kMOHD8vFixdz9k0IMoE65zcdPXpUZsyYITNnzpSwsDCPvvZgFehzDs8F6pxfu3ZNRCTL3+2wsDD55ptvJD09PQffAR9yCoCVK1c6IuLs3bv3lo9JS0tzrl27pu47d+6cU6VKFeeZZ55x3XfkyBFHRJywsDAnOTnZdX9cXJwjIs7YsWNd93Xo0MFp1KiRc/XqVdd96enpTkREhFOnTh3XfbGxsY6IOLGxsZnumzx5crZf37x585yRI0c6a9eudTZs2OCMGTPGKVKkiFOnTh3nwoUL2T7fRjbPeUpKiiMizp///OdMxxYvXuyIiJOQkOD2NWxk85zf1KdPHyciIsKVRcR58cUXc/RcGwXDnN80a9YsR0ScI0eOePQ829g85ykpKU5ISIgzbNgwdX9CQoIjIo6IOKdPn3b7Gr4WMFfqChcuLLfddpuI/H4l5OzZs5KWlibNmjWTr7/+OtPje/bsKdWqVXPlFi1aSMuWLWXbtm0iInL27FnZsWOH9OvXTy5evCinT5+W06dPy5kzZyQyMlISExPl+PHjtxxP+/btxXEcmTJlSrZjHzNmjCxcuFCefPJJ6d27t8ybN09Wr14tiYmJsmTJEg+/E8EjUOf8ypUrIiJSrFixTMduNtHefAy0QJ1zEZHY2FjZuHGjzJs3z7MvOsgF8pwjdwJ1zitWrCj9+vWT1atXy5w5c+THH3+UL7/8Uvr37y9FixYVEf//bQ+Yok5EZPXq1XLvvfdKaGioVKhQQSpVqiR///vf5cKFC5keW6dOnUz31a1b1/Vx8x9++EEcx5HXX39dKlWqpP5NnjxZREROnTrls6/lySeflKpVq8o//vEPn53DBoE45zcvzd+8VJ/R1atX1WOQWSDOeVpamowePVoGDRqk+iiRM4E458ibQJ3z5cuXS9euXSUqKkruvvtuadu2rTRq1EgeffRRERG1woU/BMynX9esWSNDhw6Vnj17yvjx46Vy5cpSuHBhmT59uhw+fNjj17v5vndUVJRERkZm+ZjatWvnaczZueuuu+Ts2bM+PUcgC9Q5L1++vBQrVkxOnDiR6djN++644448n8dGgTrn0dHRcujQIVm+fHmmdcouXrwoSUlJUrlyZSlevHiez2WbQJ1z5F4gz3mZMmVky5YtcvToUUlKSpLw8HAJDw+XiIgIqVSpkpQtW9Yr58mtgCnqNmzYILVq1ZJNmzapT53crMJNiYmJme77/vvvpUaNGiLy+4cWRESKFi0qHTt29P6As+E4jiQlJUmTJk3y/dyBIlDnvFChQtKoUaMsF9+Mi4uTWrVq8Ym5WwjUOT969KjcuHFDHnjggUzHoqOjJTo6WmJiYqRnz54+G0OgCtQ5R+7ZMOfVq1eX6tWri4jI+fPn5d///rf07t07X87tTsC8/XpzoV4nw3IgcXFxt1zgc/Pmzeo99D179khcXJx06dJFREQqV64s7du3l+XLl2d5RSUlJcXteDz52HtWr7V06VJJSUmRzp07Z/v8YBXIc96nTx/Zu3evKuwOHTokO3bskL59+2b7/GAVqHM+YMAAiYmJyfRPRKRr164SExMjLVu2dPsawSpQ5xy5Z9ucT5w4UdLS0mTs2LG5er43Fagrde+//758/PHHme4fM2aMdO/eXTZt2iS9evWSbt26yZEjR2TZsmXSsGFDuXTpUqbn1K5dW9q0aSMjRoyQa9euybx586RChQry0ksvuR6zePFiadOmjTRq1EieffZZqVWrlpw8eVJ2794tycnJcuDAgVuOdc+ePfLQQw/J5MmTs22uDA8Pl/79+0ujRo0kNDRUdu3aJevWrZPGjRvL888/n/NvkIVsnfMXXnhB3n33XenWrZtERUVJ0aJFZe7cuVKlShUZN25czr9BFrJxzuvXry/169fP8ljNmjWD/gqdjXMuInLhwgVZuHChiIh89dVXIiKyaNEiKVu2rJQtWzaot4izdc5nzJgh8fHx0rJlSylSpIhs3rxZPv30U3nzzTcLRj9t/n/gNrObH4G+1b9jx4456enpzrRp05zw8HCnWLFiTpMmTZytW7c6Q4
YMccLDw12vdfMj0LNmzXLmzJnj3HXXXU6xYsWcBx980Dlw4ECmcx8+fNgZPHiwU7VqVado0aJOtWrVnO7duzsbNmxwPSavH3sfPny407BhQ6dUqVJO0aJFndq1azsTJkxwUlNT8/JtC2i2z7njOM6xY8ecPn36OKVLl3ZKlizpdO/e3UlMTMzttyzgBcOcm4QlTaye85tjyupfxrEHE9vnfOvWrU6LFi2cUqVKOcWLF3datWrlrF+/Pi/fMq8KcRy2NwAAAAh0AdNTBwAAgFujqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWCDHO0pk3J8NgSMvyxAy54GJOQ8+zHnwYc6DT07mnCt1AAAAFqCoAwAAsABFHQAAgAUo6gAAACxAUQcAAGABijoAAAALUNQBAABYgKIOAADAAhR1AAAAFqCoAwAAsABFHQAAgAUo6gAAACxAUQcAAGABijoAAAALUNQBAABYoIi/BwAUBE2bNlV55MiRKg8ePFjl6OholRcuXKjy119/7cXRAQCQPa7UAQAAWICiDgAAwAIUdQAAABYIcRzHydEDQ0J8PRafKVy4sMplypTx6Plmf1Xx4sVVrlevnsovvviiyrNnz3bdfuKJJ9Sxq1evqjxjxgyV//SnP3k0VlMOpzdLgTzn2WncuLHKO3bsULl06dIevd6FCxdUrlChQq7G5Q3MuX906NDBdXvt2rXqWLt27VQ+dOiQV8/NnPvGa6+9prL597hQof+/LtK+fXt17PPPP/fZuESY82CUkznnSh0AAIAFKOoAAAAsQFEHAABggYBYp6569eoq33bbbSpHRESo3KZNG5XLli2rcu/evb03OBFJTk5WecGCBSr36tXLdfvixYvq2IEDB1T2dR9GsGrRooXKGzduVNnsszR7F8x5u379uspmD12rVq1ct80168zn2qRt27Yqm9+XmJiY/BxOvmrevLnr9t69e/04EuTW0KFDVZ4wYYLK6enpt3xuXnrcAG/hSh0AAIAFKOoAAAAsUCDffs1uuQlPlyTxNvMSvPmx90uXLqmccXmDEydOqGPnzp1T2dtLHQQLc5mZ+++/X+U1a9aofPvtt3v0+omJiSq/9dZbKq9bt07lr776ynXb/PmYPn26R+cOJOayDnXq1FHZprdfMy5nISJSs2ZN1+3w8HB1jCUkAoM5b6GhoX4aCW6lZcuWKg8cOFBlc/mgP/zhD25fLyoqSuWff/5ZZbOdK+N/S+Li4twP1g+4UgcAAGABijoAAAALUNQBAABYoED21B09elTlM2fOqOztnjrzffHz58+r/NBDD6lsLknxl7/8xavjgeeWL1+usrkdW16ZPXolS5ZU2VyKJmNv2b333uvVsRRkgwcPVnn37t1+GonvmX2Zzz77rOu22cOZkJCQL2OCZzp27KjyqFGj3D7enMfu3bu7bp88edJ7A4NL//79VZ4/f77KFStWVNnsX925c6fKlSpVUnnWrFluz2++XsbnDxgwwO1z/YErdQAAABagqAMAALAARR0AAIAFCmRP3dmzZ1UeP368yhn7GEREvvnmG5XNbbpM+/fvV7lTp04q//rrryqb69yMGTPG7evD95o2bapyt27dVM5uXTCzB+7DDz9Uefbs2SqbaxeZP3PmeoMPP/xwjsdiE3PtNputWLHilsfMdQ1RMJhrjq1cuVLl7Pq1zf6rn376yTsDC2JFiugypFmzZiq/++67Kptrkn7xxRcqv/HGGyrv2rVL5WLFiqm8fv16lR955BG34923b5/b4/4WPH+BAQAALEZRBwAAYAGKOgAAAAsUyJ460+bNm1U294K9ePGiyvfdd5/Kw4YNU9nslzJ76Ezffvutys8995zbx8P7zP2At2/frnLp0qVVdhxH5Y8++khlcx07c79Ac79Ws38qJSVF5QMHDqiccX9gs9/PXPPu66+/lkBlrsFXpUoVP40k/7nrvzJ/PlEwDBkyROU77rjD7ePNNc6io6O9PaSgZ+7d6q5XVSTz75a5jl1qaqrb55uPz66HLjk5WeXVq1e7fby/caUOAADAAhR1AAAAFqCoAwAAsEBA9NSZsnvP/MKFC26PZ9yjUUTkb3/7m8oZ+6HgH3Xr1lXZXKvQ7Gc6ffq0yidOnFDZ7IO4dOmSyn//+9/d5rwICwtTedy4cSo/9dRTXjtXfuvatavK5tdqE7NfsGbNmrd87PHjx309HOSAuS/oM888o7L5t97c9/vNN9/0ybiCmbmO3CuvvKKy2Q+9ZMkSlc1+5+zqAdOrr77q0eNHjx6tstlPXdBwpQ4AAMACFHUAAAAWoKgDAACwQED21GVnypQpKpv7hJprknXs2FHlTz/91Cfjwq2Z+/GZawmavVvm2oSDBw9W2dyfryD1elWvXt3fQ/CaevXquT1urvEYyMyfSbPH7vvvv3fdNn8+kT9q1Kih8saNGz16/sKFC1WOjY3N65CC3qRJk1Q2e+iuX7+u8ieffKLyhAkTVL5y5Yrb84WGhqpsrkNn/v019+Y2+yi3bNni9nwFDVfqAAAALEBRBwAAYAGKOgAAAAtY2VNn7uVqrktn7rX57rvvqmz2UZj9WYsXL1bZXFcHnmvSpInKZg+d6bHHHlP5888/9/qYkHd79+719xBuydwvuHPnziqbe1Jmt0dkxvW3zPXOkD/MOTT3JjZ99tlnKs+fP9/rYwpGZcuWdd1+4YUX1DHzv5dmD13Pnj09Olft2rVVXrt2rcpmT71pw4YNKr/11lsenb+g4UodAACABSjqAAAALEBRBwAAYAEre+pMhw8fVnno0KEqr1y5UuVBgwa5zSVKlFA5OjpaZXPfUWRv7ty5KptrB5k9cwW9h65Qof///6Vg3ku4fPnyuX7ufffdp7L5M2GuL3nnnXeqfNttt6ls7rGbcY5EMq9/FRcXp/K1a9dULlJE//n897//LchfZv/VjBkz3D5+165dKg8ZMkTl7PYNR85k/N0z9981mXurVq5cWeWnn35a5R49eqh8zz33qFyyZEmVzR4+M69Zs0Zlsyc/0HClDgAAwAIUdQAAABagqAMAALBAUPTUmWJiYlROTExU2ezv6tChg8rTpk1TOTw8XOWpU6eqfPz48VyN03bdu3d33W7cuLE6ZvY9fPDBB/kxJK/J2Ednfi379+/P59H4jtmHZn6ty5YtU9nc99Edc40xs6cuLS1N5cuXL6v83Xffqfz++++rbK4/afZpnjx5UuXk5GSVzf2EExISBL6V171df/zxR5XNOYZ3ZNzPNSUlRR2rVKmSykeOHFHZ03Vff/75Z5VTU1NVvv3221U+ffq0yh9++KFH5yvouFIHAABgAYo6AAAAC1DUAQAAWCAoe+pM8fHxKvfr10/lRx99VGVzXbvnn39e5Tp16qjcqVOnvA7RShl7ksw1xU6dOqXy3/72t3wZU04VK1ZM5SlTptzysTt27FB54sSJvhiSX5j7Ov70008qR0RE5Pq1jx49qvLmzZtVPnjwoMr/+te/cn2urDz33HMqm71AZn8WfG/ChAkqe7oGZHbr2ME7M
u59bK4luHXrVpXNtSzNdWW3bNmi8qpVq1Q+e/asyuvWrVPZ7Kkzj9uGK3UAAAAWoKgDAACwAEUdAACABeipy0LGfgARkb/85S8qr1ixQmVzD8i2bduq3L59e9ftnTt35nl8wcDcZ9Pf++maPXSvvfaayuPHj1c545pmc+bMUccuXbrk5dEVHDNnzvT3ELzGXJ/S5OkaafCcuX7lI4884tHzzX6sQ4cO5XVI8JC5h7LZm5pX5n9v27Vrp7LZd2l7LyxX6gAAACxAUQcAAGABijoAAAAL0FMnmfeY7NOnj8rNmzdX2eyhM5l7Tn7xxRd5GF1w8vder2Yvj9kz179/f5XN3p3evXv7ZFwoOMw9pOF9n376qcrlypVz+3hzrcKhQ4d6e0goYMw9mM0eOnMvWdapAwAAQIFHUQcAAGABijoAAAALBEVPXb169VQeOXKkyo8//rjKVatW9ej1f/vtN5XNNdU83Z8wWISEhGR5WyTzfoFjxozx6VjGjh2r8uuvv65ymTJlVF67dq3KgwcP9s3AgCBWoUIFlbP7W7pkyRKVbV4TEr/75JNP/D2EAoUrdQAAABagqAMAALAARR0AAIAFrOipM3vgnnjiCZXNHroaNWrk6Xz79u1TeerUqSr7e421QJFx/SBzLSFzThcsWKDy+++/r/KZM2dUbtWqlcqDBg1S+b777lP5zjvvVPno0aMqm30bZu8O7Gf2fdatW1dlc400eG7lypUqFyrk2XWHf/7zn94cDgJAZGSkv4dQoHClDgAAwAIUdQAAABYIiLdfq1SponLDhg1VXrRokcr169fP0/ni4uJUnjVrlsrmllAsWeJ9hQsXVvmFF15Q2dyGKzU1VeU6dep4dD7zbZvY2FiVJ02a5NHrwT5mi4Cnbw0iM3M7vo4dO6ps/m29fv26yosXL1b55MmT3hscAkKtWrX8PYQChb9KAAAAFqCoAwAAsABFHQAAgAUKTE9d+fLlXbeXL1+ujpl9F3l9D93sn5ozZ47K5vIVV65cydP5kLXdu3e7bu/du1cda968udvnmkuemH2XJnPJk3Xr1qns623IYJ/WrVurvGrVKv8MJICVLVtW5ey2aDx+/LjKUVFR3h4SAsyXX36pstnrGmw971ypAwAAsABFHQAAgAUo6gAAACyQbz11LVu2VHn8+PEqt2jRwnW7WrVqeTrX5cuXVTa3mJo2bZrKv/76a57Oh9xJTk523X788cfVseeff17l1157zaPXnj9/vspLly5V+YcffvDo9QBzmzAA/hcfH69yYmKiymYP/t13361ySkqKbwbmJ1ypAwAAsABFHQAAgAUo6gAAACyQbz11vXr1cpvd+e6771TeunWrymlpaSqb686dP38+x+eCf5w4cULlKVOmuM2Ar3300Ucq9+3b108jsVdCQoLK5hqibdq0yc/hwAJmz/yKFStUnjp1qsqjRo1S2aw3Ag1X6gAAACxAUQcAAGABijoAAAALhDiO4+TogazRFJByOL1ZYs4DE3MefJjz4MOcZ6106dIqr1+/XuWOHTuqvGnTJpWffvpplQvSOrY5mXOu1AEAAFiAog4AAMACFHUAAAAWoKfOcvRdBB/mPPgw58GHOc8Zs8fOXKduxIgRKt97770qF6R16+ipAwAACBIUdQAAABagqAMAALAAPXWWo+8i+DDnwYc5Dz7MefChpw4AACBIUNQBAABYgKIOAADAAjnuqQMAAEDBxZU6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALCAdUVdUlKShISEyOzZs732mjt37pSQkBDZuXOn114T3sOcBx/mPPgw58GHOfdcgSjqVq1aJSEhIbJv3z5/D8UnNm3aJP3795datWpJ8eLFpV69ejJu3Dg5f/68v4fmN7bP+aFDh2Ts2LESEREhoaGhEhISIklJSf4ell/ZPucxMTESGRkpd9xxhxQrVkzuvPNO6dOnj8THx/t7aH5j+5zze56Z7XNu6tSpk4SEhMjIkSP9PRQRKSBFne2ee+45OXjwoAwcOFAWLFggnTt3lkWLFknr1q3lypUr/h4efGD37t2yYMECuXjxojRo0MDfw0E++O9//yvlypWTMWPGyJIlS2TEiBHyzTffSIsWLeTAgQP+Hh58gN/z4LZp0ybZvXu3v4ehFPH3AILBhg0bpH379uq+pk2bypAhQ2Tt2rUyfPhw/wwMPtOjRw85f/68lCpVSmbPni379+/395DgY5MmTcp03/Dhw+XOO++UpUuXyrJly/wwKvgSv+fB6+rVqzJu3DiZMGFClr/7/hIwV+quX78ukyZNkqZNm0qZMmWkRIkS8uCDD0psbOwtn/P2229LeHi4hIWFSbt27bJ8GyQhIUH69Okj5cuXl9DQUGnWrJl88MEH2Y7n8uXLkpCQIKdPn872sWZBJyLSq1cvERE5ePBgts8PVoE85+XLl5dSpUpl+zhogTznWalcubIUL148qFstshPIc87vee4E8pzf9NZbb0l6erpERUXl+Dn5IWCKutTUVFmxYoW0b99eZs6cKVOmTJGUlBSJjIzM8v+OoqOjZcGCBfLiiy/KxIkTJT4+Xh5++GE5efKk6zHffvuttGrVSg4ePCgvv/yyzJkzR0qUKCE9e/aUmJgYt+PZs2ePNGjQQBYtWpSrr+eXX34REZGKFSvm6vnBwLY5R/ZsmPPz589LSkqK/Pe//5Xhw4dLamqqdOjQIcfPDzY2zDk8E+hzfvToUZkxY4bMnDlTwsLCPPrafc4pAFauXOmIiLN3795bPiYtLc25du2auu/cuXNOlSpVnGeeecZ135EjRxwRccLCwpzk5GTX/XFxcY6IOGPHjnXd16FDB6dRo0bO1atXXfelp6c7ERERTp06dVz3xcbGOiLixMbGZrpv8uTJufmSnWHDhjmFCxd2vv/++1w9P9AF05zPmjXLERHnyJEjHj3PNsEy5/Xq1XNExBERp2TJks5rr73m/Pbbbzl+vk2CZc4dh9/zm4Jhzvv06eNE
RES4sog4L774Yo6e62sBc6WucOHCctttt4mISHp6upw9e1bS0tKkWbNm8vXXX2d6fM+ePaVatWqu3KJFC2nZsqVs27ZNRETOnj0rO3bskH79+snFixfl9OnTcvr0aTlz5oxERkZKYmKiHD9+/Jbjad++vTiOI1OmTPH4a/nrX/8q7733nowbN07q1Knj8fODhU1zjpyxYc5XrlwpH3/8sSxZskQaNGggV65ckd9++y3Hzw82Nsw5PBPIcx4bGysbN26UefPmefZF55OA+qDE6tWrZc6cOZKQkCA3btxw3V+zZs1Mj82qWKpbt66sX79eRER++OEHcRxHXn/9dXn99dezPN+pU6fUD5I3fPnllzJs2DCJjIyUqVOnevW1bWTDnMMzgT7nrVu3dt0eMGCA61OR3lxryzaBPufwXCDOeVpamowePVoGDRokzZs3z9Nr+UrAFHVr1qyRoUOHSs+ePWX8+PFSuXJlKVy4sEyfPl0OHz7s8eulp6eLiEhUVJRERkZm+ZjatWvnacymAwcOSI8ePeSee+6RDRs2SJEiAfPt9wsb5hyesW3Oy5UrJw8//LCsXbuWou4WbJtzZC9Q5zw6OloOHToky5cvz7Qe4cWLFyUpKcn14Sh/CZiqYsOGDVKrVi3ZtGmThISEuO6fPHlylo9PTEzMdN/3338vNWrUEBGRWrVqiYhI0aJFpWPHjt4fsOHw4cPSuXNnqVy5smzbtk1Klizp83MGukCfc3jOxjm/cuWKXLhwwS/nDgQ2zjncC9Q5P3r0qNy4cUMeeOCBTMeio6MlOjpaYmJipGfPnj4bQ3YCqqdORMRxHNd9cXFxt1z4b/Pmzeo99D179khcXJx06dJFRH5faqB9+/ayfPlyOXHiRKbnp6SkuB2PJx+B/uWXX+SRRx6RQoUKySeffCKVKlXK9jkI7DlH7gTynJ86dSrTfUlJSfLZZ59Js2bNsn1+sArkOUfuBOqcDxgwQGJiYjL9ExHp2rWrxMTESMuWLd2+hq8VqCt177//vnz88ceZ7h8zZox0795dNm3aJL169ZJu3brJkSNHZNmyZdKwYUO5dOlSpufUrl1b2rRpIyNGjJBr167JvHnzpEKFCvLSSy+5HrN48WJp06aNNGrUSJ599lmpVauWnDx5Unbv3i3JycluV4Hfs2ePPPTQQzJ58uRsmys7d+4sP/74o7z00kuya9cu2bVrl+tYlSpVpFOnTjn47tjJ1jm/cOGCLFy4UEREvvrqKxERWbRokZQtW1bKli1bYLaU8Qdb57xRo0bSoUMHady4sZQrV04SExPlvffekxs3bsiMGTNy/g2ykK1zzu/5rdk45/Xr15f69etneaxmzZp+vULn4odP3GZy8yPQt/p37NgxJz093Zk2bZoTHh7uFCtWzGnSpImzdetWZ8iQIU54eLjrtW5+BHrWrFnOnDlznLvuusspVqyY8+CDDzoHDhzIdO7Dhw87gwcPdqpWreoULVrUqVatmtO9e3dnw4YNrsfk9SPQ7r62du3a5eE7F7hsn/ObY8rqX8axBxPb53zy5MlOs2bNnHLlyjlFihRx7rjjDmfAgAHOf/7zn7x82wKa7XPO73lmts95VqQALWkS4jgZrn8CAAAgIAVMTx0AAABujaIOAADAAhR1AAAAFqCoAwAAsABFHQAAgAUo6gAAACxAUQcAAGCBHO8okXF/NgSOvCxDyJwHJuY8+DDnwYc5Dz45mXOu1AEAAFiAog4AAMACFHUAAAAWoKgDAACwAEUdAACABSjqAAAALEBRBwAAYAGKOgAAAAtQ1AEAAFiAog4AAMACFHUAAAAWoKgDAACwAEUdAACABSjqAAAALEBRBwAAYIEi/h4A4Avz589XefTo0SrHx8er3L17d5V/+ukn3wwMAGClzz77TOWQkBCVH374YZ+PgSt1AAAAFqCoAwAAsABFHQAAgAXoqctCqVKlVC5ZsqTK3bp1U7lSpUoqz507V+Vr1655cXTISo0aNVQeOHCgyunp6So3aNBA5fr166tMT13BV7duXZWLFi2qctu2bVVesmSJyubPRF5t2bLFdXvAgAHq2PXr1716LvzOnPOIiAiVp02bpvIDDzzg8zEheLz99tsqmz9/0dHR+TkcEeFKHQAAgBUo6gAAACxAUQcAAGCBoOypM/uvJkyYoHLr1q1Vvueeezx6/dtvv11lc400eF9KSorKX3zxhco9evTIz+HAC/7whz+oPHToUJX79u2rcqFC+v9R77jjDpXNHjrHcfI4Qi3jz9iyZcvUsT/+8Y8qp6amevXcwapMmTIqx8bGqvzLL7+oXLVqVbfHgezMmDHDdft//ud/1LEbN26obK5blx+4UgcAAGABijoAAAALWPn2q7k8hfnWx1NPPaVyWFiYyubWHseOHVP54sWLKpvLY/Tr10/ljEspJCQk3GLUyItff/1VZZYkCXzTp09XuWvXrn4aiecGDx6s8nvvvafyV199lZ/DCVrm2628/Yq8atWqleu2uaTOrl27VF6/fn2+jCkjrtQBAABYgKIOAADAAhR1AAAAFgjInjrzY+wzZ85UuX///iqb235lJzExUeXIyEiVzffRzT65ihUrus3wvrJly6p83333+Wcg8Jrt27ernF1P3alTp1Q2+9jMJU+y2ybM3PKnXbt2bh+Pgsfsj0bgM7f/e/XVV1V+4oknVD579myezme+XsYlzg4fPqyORUVF5elc3sCVOgAAAAtQ1AEAAFiAog4AAMACAdlT16tXL5WHDx+ep9cz3xfv1KmTyuY6dbVr187T+eB9xYsXV7l69eoePb958+Yqm32SrHuX/5YuXary5s2b3T7e3KInr2uQlS5dWuX4+HiVzW3IMjLHum/fvjyNBbljbgUXGhrqp5HAW9555x2V69Spo3LDhg1VNteO89Qrr7yicoUKFVy3n332WXXswIEDeTqXN3ClDgAAwAIUdQAAABagqAMAALBAQPbU9e3b16PHJyUlqbx3716VJ0yYoLLZQ2cy93qF//38888qr1q1SuUpU6a4fb55/Pz58yovWrQolyNDbqWlpamc3e+lt5nrU5YrVy7Hz01OTlb52rVrXhkT8qZZs2Yq/+tf//LTSJBbly9fVtnbfZONGzdWOTw8XOWM61sWxB5NrtQBAABYgKIOAADAAhR1AAAAFgjInjpzbZjnnntO5U8//VTlH374QWVzj0hPValSJU/Ph++98cYbKmfXUwcMGDBAZfPvTFhYWI5fa9KkSV4ZE9wz+y4vXLigsrlP+N133+3zMcG7zL/ljRo1UvngwYMqe7pWXIkSJVQ2e+zNNVAz9mFu2LDBo3PlB67UAQAAWICiDgAAwAIUdQAAABYIyJ46c02y/O6Xat26db6eD3lXqJD+/5eMaw0hODz11FMqv/zyyyqbezoXLVrUo9ffv3+/67a5Dy18w1xP8ssvv1S5e/fu+TgaeMNdd92lstnbavZRjhw5UuWUlBSPzjd37lyVzXVwzXrjgQce8Oj18xtX6gAAACxAUQcAAGABijoAAAALBGRPXV6NHj1aZXOdmuy
Y6+SY/vnPf6q8e/duj14f3mf20Jn7BaLgqVGjhsqDBg1SuWPHjh69Xps2bVT29GcgNTVVZbMnb9u2ba7bV65c8ei1gWB1zz33qBwTE6NyxYoVVV64cKHKn3/+uUfni4qKUnno0KFuHz916lSPXt/fuFIHAABgAYo6AAAAC1DUAQAAWMCKnjpzb7aGDRuqPHnyZJW7du3q9vU8XdPMXMfm6aefVvm3335z+3wAmXtrPvjgA5WrV6+en8PJxFwD7Z133vHTSJBbFSpU8PcQgk6RIrrMGDhwoMrvvfeeytn999dcJ3bixIkqm+vOlS9fXmVzHbqQkBCVo6OjVV6+fLkEEq7UAQAAWICiDgAAwAIUdQAAABYIiJ46cw/GJk2aqLxx40aVb7/9dpXNNaPMHjhzHbnOnTurbPbsmcyegccff1zl+fPnu25fv37d7WsB+J3Z62JmT+V1/19zH9EuXbqo/NFHH+VuYMg3PXr08PcQgs6AAQNUXrFihcrmepHm7+UPP/ygcrNmzdzmxx57TOVq1aqpbNYH5l6xzzzzjAQyrtQBAABYgKIOAADAAhR1AAAAFiiQPXW33XabymaP26ZNm9w+/09/+pPKO3bsUPmrr75S2VzHxny8uX6WqVKlSipPnz5d5aNHj7pub968WR27du2a29eGd3jaT9W2bVuVFy1a5PUxQYuPj1e5ffv2KpvrW33yyScqX716NU/nHzZsmMqjRo3K0+sh/8XGxqps9kEif/Tv3991e+XKlerYjRs3VD5//rzKTz75pMrnzp1Tec6cOSq3a9dOZbPHzuzFNXv4zL1ljx07prL5d+jw4cNSkHGlDgAAwAIUdQAAABagqAMAALBAiGO+wXyrB+ZxjajsZFyL7s9//rM6Nn78eLfPNdeHGjRokMrme/ZmD9y2bdtUvv/++1U215Z76623VDZ77sx1cjL6xz/+ofLMmTNVNvsHTPv373d73JTD6c2Sr+c8P5n773r6fbn33ntV/u677/I8Jl9hznOnTJkyKp85c8bt4x999FGV/blOHXP+u969e6v8v//7vyqba5aa+4T/9NNPvhmYDxTkOc/Ylx4eHq6OvfnmmyqbPXfZMefM3JvV3Bs2u54601//+leVBw8e7NH4fCknc86VOgAAAAtQ1AEAAFiAog4AAMACflunrnDhwiq/8cYbrttRUVHq2K+//qryyy+/rPK6detUNnvozHVrzDXHzL1kExMTVR4xYoTK5lpIpUuXVjkiIkLlp556ynXb3Htw+/bt4o65Zk7NmjXdPh5ZW7ZsmcrPP/+8R89/7rnnVP7jH/+Y1yGhgImMjPT3EJBHaWlpbo+b/VXFihXz5XCC1pYtW1y3zXVlzf+mecpcVy67dWSfeOIJlc31ME3Jycm5G1gBwZU6AAAAC1DUAQAAWICiDgAAwAJ+66kze5Qy9tFdvnxZHTP7nz799FOVW7VqpfLTTz+tcpcuXVQOCwtT2VwXz1w3J7segNTUVJU//vjjW2bz/X1znzvT2LFj3R5HziQkJPh7CBC9HuUjjzyijpl7Lptrinmb+Xdi/vz5Pj0ffC9jL5dI5t/7+vXrq2z2xr7wwgs+GVew8ebvkrl+ZN++fVU2e9rNvVnXr1/vtbEEAq7UAQAAWICiDgAAwAIUdQAAABbw296vJ06cUDnjfqzXrl1Tx8y+iBIlSqhcu3Ztj849ZcoUladPn66yuU9oICvI+wP60/fff6/y3Xff7fbxhQrp//8xf+bMPg5/Kkhz3qZNG5VfffVV1+1OnTqpY+YajHldz6p8+fIqd+3aVeWFCxeqXKpUKbevZ/b4mWtOmutX5qeCNOcFybx581Q2+yirVKmi8tWrV309JK8JljmfOHGiyhnXtBURSUlJUbl58+YqB/q6cxmx9ysAAECQoKgDAACwgN+WNPnll19Uzvj2q7l1y3333ef2tbZt26byF198ofLmzZtVTkpKUtmmt1uRM99++63KtWrVcvv49PR0Xw7HWuaWfO629HnppZdUvnjxYp7Obb69e//996uc3VsZO3fuVHnp0qUq+/PtVuSOOefXr1/300hwK+Hh4SoPHz5cZXMO33nnHZVters1N7hSBwAAYAGKOgAAAAtQ1AEAAFjAbz11bdu2Vblnz56u22bvy6lTp1R+//33VT537pzK9EkgO2YfxqOPPuqnkeCmESNG5Ov5zL8rH374ocpjxoxROZCWu0DWzC2lHnvsMZVjYmLyczjIwvbt21U2e+zWrFmj8uTJk30+pkDClToAAAALUNQBAABYgKIOAADAAn7bJgz5I1i2kvGU2aexdetWlRs0aKCy+b2oW7euymwTlrXGjRurPGrUKNftIUOGePVc5hxcvnxZ5S+//FJls68yPj7eq+PJTwVpzguSn3/+WeVy5cqp3KRJE5XNLSkLMlvnPLttwfr27atyMPVBsk0YAABAkKCoAwAAsABFHQAAgAXoqbOcrX0XuLWCPOcZ93UeOnSoOvbmm2+qbPY/mXs4m+tZbdmyRWVzf2mbFeQ596d169apbPbK9ujRQ+WffvrJ52PyFuY8+NBTBwAAECQo6gAAACxAUQcAAGABeuosR99F8GHOgw9zHnyY8+BDTx0AAECQoKgDAACwAEUdAACABSjqAAAALEBRBwAAYAGKOgAAAAtQ1AEAAFiAog4AAMACFHUAAAAWoKgDAACwAEUdAACABXK89ysAAAAKLq7UAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWOD/AN0W3QAP3VaMAAAAAElFTkSuQmCC" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Get the first 10 samples\n", "dataiter = iter(trainset)\n", "images, labels = [], []\n", "\n", "for i in range(10):\n", " image, label = next(dataiter)\n", " images.append(image)\n", " labels.append(label)\n", "\n", "# Plot the samples\n", "fig, axes = plt.subplots(2, 5)\n", "\n", "for ax, img, lbl in zip(axes.ravel(), images, labels):\n", " ax.imshow(img.squeeze().numpy(), cmap='gray')\n", " ax.set_title(f'Label: {lbl}')\n", " ax.axis('off')\n", "\n", "plt.tight_layout()\n", "plt.show()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:40:28.606350630Z", "start_time": "2023-10-03T12:40:28.277928820Z" } }, "id": 
"87c6eae807f51118" }, { "cell_type": "markdown", "source": [ "Define the MLP and Convolutional Autoencoder" ], "metadata": { "collapsed": false }, "id": "e4e25962ef8e5b0d" }, { "cell_type": "code", "execution_count": 5, "outputs": [], "source": [ "class Autoencoder(nn.Module):\n", " def __init__(self, input_size, hidden_size, type='mlp'):\n", " super(Autoencoder, self).__init__()\n", " # type of autoencoder\n", " self.type = type\n", " if self.type == 'mlp':\n", " self.encoder = nn.Sequential(\n", " nn.Linear(input_size, hidden_size),\n", " nn.ReLU(True))\n", " self.decoder = nn.Sequential(\n", " nn.Linear(hidden_size, input_size),\n", " nn.ReLU(True),\n", " nn.Sigmoid()\n", " )\n", " elif self.type == 'cnn':\n", " # Encoder module\n", " self.encoder = nn.Sequential(\n", " nn.Conv2d(in_channels=input_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1),\n", " nn.ReLU(),\n", " nn.Conv2d(in_channels=hidden_size//2, out_channels=hidden_size, kernel_size=3, stride=2, padding=1),\n", " nn.ReLU()\n", " )\n", " # Decoder module\n", " self.decoder = nn.Sequential(\n", " nn.ConvTranspose2d(in_channels=hidden_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1, output_padding=1),\n", " nn.ReLU(),\n", " nn.ConvTranspose2d(in_channels=hidden_size//2, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1),\n", " nn.Sigmoid() # Sigmoid to ensure the output is between 0 and 1\n", " )\n", "\n", " def forward(self, x):\n", " x = self.encoder(x)\n", " x = self.decoder(x)\n", " return x" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:40:29.561602021Z", "start_time": "2023-10-03T12:40:29.559204154Z" } }, "id": "26f2513d92b78e1e" }, { "cell_type": "markdown", "source": [ "Check if GPU support is available" ], "metadata": { "collapsed": false }, "id": "91a01313b4d95274" }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "# device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:40:30.834696308Z", "start_time": "2023-10-03T12:40:30.827970016Z" } }, "id": "67006b35b75d8dff" }, { "cell_type": "markdown", "source": [ "Define the training function" ], "metadata": { "collapsed": false }, "id": "8eebf70cb27640d5" }, { "cell_type": "code", "execution_count": 7, "outputs": [], "source": [ "# Define the training function\n", "def train(model, train_loader, optimizer, criterion, epoch, verbose=True):\n", " model.train()\n", " train_loss = 0\n", " for i, (data, _) in enumerate(train_loader):\n", " # check the type of autoencoder and modify the input data accordingly\n", " if model.type == 'mlp':\n", " data = data.view(data.size(0), -1)\n", " data = data.to(device)\n", " optimizer.zero_grad()\n", " output = model(data)\n", " loss = criterion(output, data)\n", " loss.backward()\n", " train_loss += loss.item()\n", " optimizer.step()\n", " train_loss /= len(train_loader.dataset)\n", " if verbose:\n", " print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}') \n", " return train_loss" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:40:31.675127463Z", "start_time": "2023-10-03T12:40:31.655370005Z" } }, "id": "5f96f7be13984747" }, { "cell_type": "markdown", "source": [ "The evaluation functions for the linear classification and clustering tasks" ], "metadata": { "collapsed": 
false }, "id": "5f6386edcab6b1e4" }, { "cell_type": "code", "execution_count": 8, "outputs": [], "source": [ "# Extract encoded representations for a given loader\n", "def extract_features(loader, model):\n", " features = []\n", " labels = []\n", " model.eval()\n", " with torch.no_grad():\n", " for data in loader:\n", " img, label = data\n", " if model.type == 'mlp':\n", " img = img.view(img.size(0), -1)\n", " img = img.to(device)\n", " feature = model.encoder(img)\n", " if model.type == 'cnn':\n", " feature = feature.view(feature.size(0), -1) # Flatten the CNN encoded features\n", " features.append(feature)\n", " labels.append(label)\n", " return torch.cat(features), torch.cat(labels)\n", "\n", "# Define the loss test function\n", "def test_loss(model, test_loader, criterion):\n", " model.eval()\n", " eval_loss = 0\n", " with torch.no_grad():\n", " for i, (data, _) in enumerate(test_loader):\n", " # check the type of autoencoder and modify the input data accordingly\n", " if model.type == 'mlp':\n", " data = data.view(data.size(0), -1)\n", " data = data.to(device)\n", " output = model(data)\n", " eval_loss += criterion(output, data).item()\n", " eval_loss /= len(test_loader.dataset)\n", " print('====> Test set loss: {:.4f}'.format(eval_loss))\n", " return eval_loss\n", "\n", "# Define the linear classification test function\n", "def test_linear(encoded_train, train_labels, encoded_test, test_labels):\n", " train_features_np = encoded_train.cpu().numpy()\n", " train_labels_np = train_labels.cpu().numpy()\n", " test_features_np = encoded_test.cpu().numpy()\n", " test_labels_np = test_labels.cpu().numpy()\n", " \n", " # Apply logistic regression on train features and labels\n", " logistic_regression = LogisticRegression(random_state=0, max_iter=100).fit(train_features_np, train_labels_np)\n", " print(f\"Train accuracy: {logistic_regression.score(train_features_np, train_labels_np)}\")\n", " # Apply logistic regression on test features and labels\n", " test_accuracy = logistic_regression.score(test_features_np, test_labels_np)\n", " print(f\"Test accuracy: {test_accuracy}\")\n", " return test_accuracy\n", "\n", "\n", "def test_clustering(encoded_features, true_labels):\n", " encoded_features_np = encoded_features.cpu().numpy()\n", " true_labels_np = true_labels.cpu().numpy()\n", " \n", " # Apply k-means clustering\n", " kmeans = KMeans(n_clusters=10, n_init=10, random_state=0).fit(encoded_features_np)\n", " cluster_labels = kmeans.labels_\n", " \n", " # Evaluate clustering results using Adjusted Rand Index\n", " ari_score = adjusted_rand_score(true_labels_np, cluster_labels)\n", " print(f\"Clustering ARI score: {ari_score}\")\n", " return ari_score\n", "\n", "def knn_classifier(encoded_train, train_labels, encoded_test, test_labels, k=5):\n", " encoded_train_np = encoded_train.cpu().numpy()\n", " encoded_test_np = encoded_test.cpu().numpy()\n", " train_labels_np = train_labels.cpu().numpy()\n", " test_labels_np = test_labels.cpu().numpy()\n", " \n", " # Apply k-nearest neighbors classification\n", " knn = KNeighborsClassifier(n_neighbors=k).fit(encoded_train_np, train_labels_np)\n", " accuracy_score = knn.score(encoded_test_np, test_labels_np)\n", " print(f\"KNN accuracy: {accuracy_score}\")\n", " return accuracy_score\n", "\n", "def test(model, train_loader, test_loader, criterion):\n", " # Extract features once for all tests\n", " encoded_train, train_labels = extract_features(train_loader, model)\n", " encoded_test, test_labels = extract_features(test_loader, model)\n", " 
print(f\"{model.type} Autoencoder\")\n", " results = {\n", " 'reconstruction_loss': test_loss(model, test_loader, criterion),\n", " 'linear_classification_accuracy': test_linear(encoded_train, train_labels, encoded_test, test_labels),\n", " 'knn_classification_accuracy': knn_classifier(encoded_train, train_labels, encoded_test, test_labels),\n", " 'clustering_ari_score': test_clustering(encoded_test, test_labels)\n", " }\n", " \n", " # Save results to a log file\n", " with open(\"evaluation_results.log\", \"w\") as log_file:\n", " for key, value in results.items():\n", " log_file.write(f\"{key}: {value}\")\n", " \n", " return results\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:40:32.662180572Z", "start_time": "2023-10-03T12:40:32.657785583Z" } }, "id": "b2c4483492fdd427" }, { "cell_type": "code", "execution_count": 9, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLP AE parameters: 201616\n", "CNN AE parameters: 148865\n" ] } ], "source": [ "# Define the training parameters for the fully connected MLP Autoencoder\n", "batch_size = 32\n", "epochs = 5\n", "hidden_size = 128\n", "train_frequency = epochs\n", "test_frequency = epochs\n", "\n", "# Create the fully connected MLP Autoencoder\n", "input_size = trainset.data.shape[1] * trainset.data.shape[2]\n", "ae = Autoencoder(input_size, hidden_size, type='mlp').to(device)\n", "input_size=1\n", "cnn_ae = Autoencoder(input_size, hidden_size, type='cnn').to(device)\n", "# print the models' number of parameters\n", "print(f\"MLP AE parameters: {sum(p.numel() for p in ae.parameters())}\")\n", "print(f\"CNN AE parameters: {sum(p.numel() for p in cnn_ae.parameters())}\")\n", "\n", "# Define the loss function and optimizer\n", "criterion = nn.MSELoss()\n", "optimizer = optim.Adam(ae.parameters(), lr=1e-3)\n", "optimizer_cnn = optim.Adam(cnn_ae.parameters(), lr=1e-3)\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:40:33.593858793Z", "start_time": "2023-10-03T12:40:33.153759806Z" } }, "id": "bcb22bc5af9fb014" }, { "cell_type": "code", "execution_count": 10, "outputs": [], "source": [ "# Create the train and test dataloaders\n", "train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n", "test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:40:34.157176359Z", "start_time": "2023-10-03T12:40:34.153720678Z" } }, "id": "f1626ce45bb25883" }, { "cell_type": "code", "execution_count": 11, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 80%|████████ | 4/5 [01:02<00:15, 15.79s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "mlp====> Epoch: 5 Average loss: 0.0598\n", "cnn====> Epoch: 5 Average loss: 0.0260\n", "mlp Autoencoder\n", "====> Test set loss: 0.0598\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. 
of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train accuracy: 0.25626666666666664\n", "Test accuracy: 0.2649\n", "KNN accuracy: 0.2295\n", "Clustering ARI score: 0.0614873771495409\n", "cnn Autoencoder\n", "====> Test set loss: 0.0260\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train accuracy: 0.9316166666666666\n", "Test accuracy: 0.9278\n", "KNN accuracy: 0.9639\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 5/5 [06:56<00:00, 83.22s/it] " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Clustering ARI score: 0.3909294873624941\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "test_mlp = []\n", "test_cnn = []\n", "# Train the model\n", "for epoch in tqdm(range(1, epochs + 1)):\n", " verbose = True if epoch % train_frequency == 0 else False\n", " train(ae, train_loader, optimizer, criterion, epoch, verbose)\n", " train(cnn_ae, train_loader, optimizer_cnn, criterion, epoch, verbose)\n", "\n", " # test every n epochs\n", " if epoch % test_frequency == 0:\n", " restults_dic = test(ae, train_loader, test_loader, criterion)\n", " test_mlp.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])\n", " restults_dic = test(cnn_ae, train_loader, test_loader, criterion)\n", " test_cnn.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:47:30.814501313Z", "start_time": "2023-10-03T12:40:34.720164326Z" } }, "id": "7472159fdc5f2532" }, { "cell_type": "markdown", "source": [ "Compare the evaluation results of the MLP and CNN Autoencoders" ], "metadata": { "collapsed": false }, "id": "10639256e342a159" }, { "cell_type": "code", "execution_count": 12, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model Reconstruction Loss Linear Accuracy KNN Accuracy Clustering ARI \n", "MLP AE 0.0598 0.2649 0.2295 0.0615 \n", "CNN AE 0.0260 0.9278 0.9639 0.3909 \n" ] } ], "source": [ "print(f\"{'Model':<10} {'Reconstruction Loss':<20} {'Linear Accuracy':<20} {'KNN Accuracy':<20} {'Clustering ARI':<20}\")\n", "print(f\"{'MLP AE':<10} {test_mlp[-1][0]:<20.4f} {test_mlp[-1][1]:<20.4f} {test_mlp[-1][2]:<20.4f} {test_mlp[-1][3]:<20.4f}\")\n", 
"print(f\"{'CNN AE':<10} {test_cnn[-1][0]:<20.4f} {test_cnn[-1][1]:<20.4f} {test_cnn[-1][2]:<20.4f} {test_cnn[-1][3]:<20.4f}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:47:30.828062767Z", "start_time": "2023-10-03T12:47:30.812448850Z" } }, "id": "50bb4c3c58af09ee" }, { "cell_type": "markdown", "source": [ "Develop a linear classifier with fully connected layers" ], "metadata": { "collapsed": false }, "id": "b9201d1403781706" }, { "cell_type": "code", "execution_count": 13, "outputs": [], "source": [ "# Define the fully connected classifier for MNIST\n", "class DenseClassifier(nn.Module):\n", " def __init__(self, input_size=784, hidden_size=500, num_classes=10):\n", " super(DenseClassifier, self).__init__()\n", " self.type = 'mlp'\n", " self.encoder = nn.Sequential(\n", " nn.Linear(input_size, hidden_size),\n", " nn.ReLU(True))\n", " self.fc1 = nn.Linear(hidden_size, num_classes)\n", "\n", " def forward(self, x):\n", " x = x.view(x.size(0), -1) # Flatten the input tensor\n", " x = self.encoder(x)\n", " x = self.fc1(x)\n", " return x\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:47:30.833296890Z", "start_time": "2023-10-03T12:47:30.819270525Z" } }, "id": "1612800950703181" }, { "cell_type": "markdown", "source": [ "Develop a non-linear classifier with convolutional layers" ], "metadata": { "collapsed": false }, "id": "35db4190e9c7f716" }, { "cell_type": "code", "execution_count": 14, "outputs": [], "source": [ "# cnn classifier\n", "class CNNClassifier(nn.Module):\n", " def __init__(self, input_size=3, hidden_size=128, num_classes=10):\n", " super(CNNClassifier, self).__init__()\n", " self.type = 'cnn'\n", " # Encoder (Feature extractor)\n", " self.encoder = nn.Sequential(\n", " nn.Conv2d(in_channels=input_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1),\n", " nn.ReLU(),\n", " nn.Conv2d(in_channels=hidden_size//2, out_channels=hidden_size, kernel_size=3, stride=2, padding=1),\n", " nn.ReLU()\n", " )\n", " \n", " # Classifier\n", " # Here, for the sake of example, I'm assuming the spatial size of the encoder output \n", " # is 7x7 for an input size of 28x28. 
{ "cell_type": "markdown", "source": [ "Train and test functions for the classifiers" ], "metadata": { "collapsed": false }, "id": "7c3cf2371479da0c" }, { "cell_type": "code", "execution_count": 15, "outputs": [], "source": [ "# Train function for the classifiers\n", "def train_classifier(model, train_loader, optimizer, criterion, epoch, verbose=True):\n", "    model.train()\n", "    train_loss = 0\n", "    correct = 0\n", "    for i, (data, target) in enumerate(train_loader):\n", "        if model.type == 'cnn':\n", "            data = data.to(device)\n", "        else:\n", "            data = data.view(data.size(0), -1)\n", "            data = data.to(device)\n", "        target = target.to(device)\n", "        optimizer.zero_grad()\n", "        output = model(data)\n", "        loss = criterion(output, target)\n", "        loss.backward()\n", "        train_loss += loss.item()\n", "        optimizer.step()\n", "        # Calculate correct predictions for training accuracy\n", "        pred = output.argmax(dim=1, keepdim=True)\n", "        correct += pred.eq(target.view_as(pred)).sum().item()\n", "\n", "    train_loss /= len(train_loader.dataset)\n", "    train_accuracy = 100. * correct / len(train_loader.dataset)\n", "    if verbose:\n", "        print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}')\n", "        print(f'{model.type}====> Epoch: {epoch} Training accuracy: {train_accuracy:.2f}%')\n", "    return train_loss\n", "\n", "\n", "def test_classifier(model, test_loader, criterion):\n", "    model.eval()\n", "    eval_loss = 0\n", "    correct = 0\n", "    with torch.no_grad():\n", "        for i, (data, target) in enumerate(test_loader):\n", "            if model.type == 'cnn':\n", "                data = data.to(device)\n", "            else:\n", "                data = data.view(data.size(0), -1)\n", "                data = data.to(device)\n", "            target = target.to(device)\n", "            output = model(data)\n", "            eval_loss += criterion(output, target).item()\n", "            pred = output.argmax(dim=1, keepdim=True)\n", "            correct += pred.eq(target.view_as(pred)).sum().item()\n", "    eval_loss /= len(test_loader.dataset)\n", "    print('====> Test set loss: {:.4f}'.format(eval_loss))\n", "    accuracy = correct / len(test_loader.dataset)\n", "    print('====> Test set accuracy: {:.4f}'.format(accuracy))\n", "    return accuracy\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:47:30.924063759Z", "start_time": "2023-10-03T12:47:30.875486950Z" } }, "id": "ac980d25bd8a3dd3" }, { "cell_type": "code", "execution_count": 16, "outputs": [], "source": [ "# Define the training parameters for the classifiers\n", "batch_size = 32\n", "epochs = 5\n", "learning_rate = 1e-3\n", "hidden_size = 128\n", "num_classes = 10\n", "train_frequency = epochs\n", "test_frequency = epochs\n", "# Create the fully connected and convolutional classifiers\n", "input_size = trainset.data.shape[1] * trainset.data.shape[2]\n", "classifier = DenseClassifier(input_size, hidden_size, num_classes).to(device)\n", "input_size = 1\n", "cnn_classifier = CNNClassifier(input_size, hidden_size, num_classes).to(device)" ], "metadata": { 
"collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:47:30.924544974Z", "start_time": "2023-10-03T12:47:30.875705556Z" } }, "id": "dff05e622dcfd774" }, { "cell_type": "code", "execution_count": 17, "outputs": [], "source": [ "# Define the loss function and optimizer\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)\n", "optimizer_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:47:30.924773429Z", "start_time": "2023-10-03T12:47:30.875787620Z" } }, "id": "3104345cdee0eb00" }, { "cell_type": "markdown", "source": [ "Train the non-linear classifiers" ], "metadata": { "collapsed": false }, "id": "e1fed39be2f04745" }, { "cell_type": "code", "execution_count": 18, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 80%|████████ | 4/5 [02:10<00:32, 32.47s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "mlp====> Epoch: 5 Average loss: 0.0030\n", "mlp====> Epoch: 5 Training accuracy: 97.08%\n", "cnn====> Epoch: 5 Average loss: 0.0005\n", "cnn====> Epoch: 5 Training accuracy: 99.49%\n", "====> Test set loss: 0.0035\n", "====> Test set accuracy: 0.9686\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 5/5 [02:48<00:00, 33.62s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "====> Test set loss: 0.0013\n", "====> Test set accuracy: 0.9880\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Train the classifier\n", "for epoch in tqdm(range(1, epochs + 1)):\n", " verbose = True if epoch % train_frequency == 0 else False\n", " train_classifier(classifier, train_loader, optimizer, criterion, epoch, verbose)\n", " train_classifier(cnn_classifier, train_loader, optimizer_cnn, criterion, epoch, verbose)\n", "\n", " # test every n epochs\n", " if epoch % test_frequency == 0:\n", " test_acc = test_classifier(classifier, test_loader, criterion)\n", " test_acc_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:50:19.005163249Z", "start_time": "2023-10-03T12:47:30.875867176Z" } }, "id": "abc0c6ce338d40d9" }, { "cell_type": "markdown", "source": [ "Load the encoder weights into the classifier" ], "metadata": { "collapsed": false }, "id": "a06038f113d8434f" }, { "cell_type": "code", "execution_count": 19, "outputs": [ { "data": { "text/plain": "" }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# initialize the classifier with the encoder weights\n", "classifier.encoder.load_state_dict(ae.encoder.state_dict())\n", "cnn_classifier.encoder.load_state_dict(cnn_ae.encoder.state_dict())" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:50:19.005816667Z", "start_time": "2023-10-03T12:50:18.994175691Z" } }, "id": "6a91d8894b70ef7c" }, { "cell_type": "markdown", "source": [ "Transfer learning" ], "metadata": { "collapsed": false }, "id": "aafa4a9ba7208647" }, { "cell_type": "code", "execution_count": 20, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 80%|████████ | 4/5 [01:56<00:27, 27.00s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "mlp====> Epoch: 5 Average loss: 0.0547\n", "mlp====> Epoch: 5 Training accuracy: 38.00%\n", "cnn====> Epoch: 5 Average loss: 0.0004\n", "cnn====> Epoch: 5 Training accuracy: 99.57%\n", "====> Test 
set loss: 0.0526\n", "====> Test set accuracy: 0.4150\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 5/5 [02:25<00:00, 29.01s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "====> Test set loss: 0.0017\n", "====> Test set accuracy: 0.9868\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Fine-tune the pretrained classifiers with a smaller learning rate\n", "learning_rate = 1e-5\n", "epochs = 5\n", "train_frequency = epochs\n", "test_frequency = epochs\n", "optimizer_pretrained = optim.Adam(classifier.parameters(), lr=learning_rate)\n", "optimizer_pretrained_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n", "for epoch in tqdm(range(1, epochs + 1)):\n", "    verbose = True if epoch % train_frequency == 0 else False\n", "    train_loss = train_classifier(classifier, train_loader, optimizer_pretrained, criterion, epoch, verbose)\n", "    train_loss_cnn = train_classifier(cnn_classifier, train_loader, optimizer_pretrained_cnn, criterion, epoch, verbose)\n", "\n", "    # test every n epochs\n", "    if epoch % test_frequency == 0:\n", "        test_acc_pretrained = test_classifier(classifier, test_loader, criterion)\n", "        test_acc_pretrained_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:52:44.085979647Z", "start_time": "2023-10-03T12:50:19.003728990Z" } }, "id": "a60dd68f988a8249" }, { "cell_type": "markdown", "source": [ "Compare the results of linear probing with the classifiers trained from scratch and the fine-tuned pretrained classifiers" ], "metadata": { "collapsed": false }, "id": "31577275b833707a" }, { "cell_type": "code", "execution_count": 21, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model      Linear Accuracy      Non-linear accuracy  Pretrained accuracy \n", "MLP AE     0.2649               0.9686               0.4150              \n", "CNN AE     0.9278               0.9880               0.9868              \n" ] } ], "source": [ "# Print a table of the accuracies and compare them with the linear probing results\n", "print(f\"{'Model':<10} {'Linear Accuracy':<20} {'Non-linear accuracy':<20} {'Pretrained accuracy':<20}\")\n", "print(f\"{'MLP AE':<10} {test_mlp[-1][1]:<20.4f} {test_acc:<20.4f} {test_acc_pretrained:<20.4f}\")\n", "print(f\"{'CNN AE':<10} {test_cnn[-1][1]:<20.4f} {test_acc_cnn:<20.4f} {test_acc_pretrained_cnn:<20.4f}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2023-10-03T12:52:44.091206610Z", "start_time": "2023-10-03T12:52:44.084572508Z" } }, "id": "40d0e7f3f13404c9" }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [], "metadata": { "collapsed": false }, "id": "f38a1ab6951a694e" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10" } }, "nbformat": 4, "nbformat_minor": 5 }