AppliedMachineAndDeepLearni.../autoencoder_mnist.ipynb
2023-10-09 10:02:01 +00:00

1136 lines
52 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"source": [
"| Credentials | |\n",
"|----|----------------------------------|\n",
"|Host | Montanuniversitaet Leoben |\n",
"|Web | https://cps.unileoben.ac.at |\n",
"|Mail | cps@unileoben.ac.at |\n",
"|Author | Fotios Lygerakis |\n",
"|Corresponding Authors | fotios.lygerakis@unileoben.ac.at |\n",
"|Last edited | 28.09.2023 |"
],
"metadata": {
"collapsed": false
},
"id": "ae041e151c5c2222"
},
{
"cell_type": "markdown",
"source": [
"In the first part of this tutorial we will build a fully connected MLP Autoencoder on the MNIST dataset. Then we will perform linear probing on the encoder features to see how well they perform on a linear classification task."
],
"metadata": {
"collapsed": false
},
"id": "1490260facaff836"
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import matplotlib.pyplot as plt\n",
"from torchvision import datasets, transforms\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.metrics import adjusted_rand_score\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.cluster import KMeans\n",
"from tqdm import tqdm"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:40:27.590742473Z",
"start_time": "2023-10-03T12:40:25.356175335Z"
}
},
"id": "fc18830bb6f8d534"
},
{
"cell_type": "markdown",
"source": [
"Set random seed"
],
"metadata": {
"collapsed": false
},
"id": "1bad4bd03deb5b7e"
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [
{
"data": {
"text/plain": "<torch._C.Generator at 0x7f10889c50d0>"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Set random seed\n",
"torch.manual_seed(0)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:40:27.592356147Z",
"start_time": "2023-10-03T12:40:27.568457489Z"
}
},
"id": "27dd48e60ae7dd9e"
},
{
"cell_type": "markdown",
"source": [
"Load the MNIST dataset"
],
"metadata": {
"collapsed": false
},
"id": "cc7f167a33227e94"
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# Define the transformations\n",
"transform = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,))])\n",
"# Download and load the training data\n",
"trainset = datasets.MNIST('data', download=True, train=True, transform=transform)\n",
"# Download and load the test data\n",
"testset = datasets.MNIST('data', download=True, train=False, transform=transform)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:40:27.639417871Z",
"start_time": "2023-10-03T12:40:27.577605311Z"
}
},
"id": "34248e8bc2678fd3"
},
{
"cell_type": "markdown",
"source": [
"Print some examples from the dataset"
],
"metadata": {
"collapsed": false
},
"id": "928dfac955d0d778"
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "<Figure size 640x480 with 10 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnUAAAFiCAYAAACQzC7qAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAxUElEQVR4nO3dd3TUVfrH8ScUSegdFCWA9BUF6REBBQxNBKkqTUH9oZTlEEQswK7SpEgvigJZ2MPyAwLKYmElqLhsgFXYjRKMSIQgQqgBqTHf3x8e5pfnJkwyyUwmc+f9Oodz5jPfmfne5Cbx8TvP3BviOI4jAAAACGiF/D0AAAAA5B1FHQAAgAUo6gAAACxAUQcAAGABijoAAAALUNQBAABYgKIOAADAAhR1AAAAFqCoAwAAsIB1RV1SUpKEhITI7NmzvfaaO3fulJCQENm5c6fXXhPew5wHH+Y8+DDnwYc591yBKOpWrVolISEhsm/fPn8PxSemTJkiISEhmf6Fhob6e2h+Y/uci4gcP35c+vXrJ2XLlpXSpUvLY489Jj/++KO/h+U3wTDnGXXq1ElCQkJk5MiR/h6K39g+54cOHZKxY8dKRESEhIaGSkhIiCQlJfl7WH5l+5yLiKxbt07uv/9+CQ0NlUqVKsmwYcPk9OnT/h6WiIgU8fcAgsnSpUulZMmSrly4cGE/jga+dOnSJXnooYfkwoUL8sorr0jRokXl7bfflnbt2sn+/fulQoUK/h4ifGjTpk2ye/dufw8DPrZ7925ZsGCBNGzYUBo0aCD79+/395DgY0uXLpUXXnhBOnToIHPnzpXk5GSZP3++7Nu3T+Li4vx+sYaiLh/16dNHKlas6O9hIB8sWbJEEhMTZc+ePdK8eXMREenSpYvcc889MmfOHJk2bZqfRwhfuXr1qowbN04mTJggkyZN8vdw4EM9evSQ8+fPS6lSpWT27NkUdZa7fv26vPLKK9K2bVvZvn27hISEiIhIRESEPProo/Luu+/KqFGj/DrGAvH2a05cv35dJk2aJE2bNpUyZcpIiRIl5MEHH5TY2NhbPuftt9+W8PBwCQsLk3bt2kl8fHymxyQkJEifPn2kfPnyEhoaKs2aNZMPPvgg2/FcvnxZEhISPLrk6jiOpKamiuM4OX5OMAvkOd+wYYM0b97cVdCJiNSvX186dOgg69evz/b5wSqQ5/ymt956S9LT0yUqKirHzwlmgTzn5cuXl1KlSmX7OGiBOufx8fFy/vx56d+/v6ugExHp3r27lCxZUtatW5ftuXwtYIq61NRUWbFihbRv315mzpwpU6ZMkZSUFImMjMzy/46io6NlwYIF8uKLL8rEiRMlPj5eHn74YTl58qTrMd9++620atVKDh48KC+//LLMmTNHSpQoIT179pSYmBi349mzZ480aNBAFi1alOOvoVatWlKmTBkpVaqUDBw4UI0FmQXqnKenp8t//vMfadasWaZjLVq0kMOHD8vFixdz9k0IMoE65zcdPXpUZsyYITNnzpSwsDCPvvZgFehzDs8F6pxfu3ZNRCTL3+2wsDD55ptvJD09PQffAR9yCoCVK1c6IuLs3bv3lo9JS0tzrl27pu47d+6cU6VKFeeZZ55x3XfkyBFHRJywsDAnOTnZdX9cXJwjIs7YsWNd93Xo0MFp1KiRc/XqVdd96enpTkREhFOnTh3XfbGxsY6IOLGxsZnumzx5crZf37x585yRI0c6a9eudTZs2OCMGTPGKVKkiFOnTh3nwoUL2T7fRjbPeUpKiiMizp///OdMxxYvXuyIiJOQkOD2NWxk85zf1KdPHyciIsKVRcR58cUXc/RcGwXDnN80a9YsR0ScI0eOePQ829g85ykpKU5ISIgzbNgwdX9CQoIjIo6IOKdPn3b7Gr4WMFfqChcuLLfddpuI/H4l5OzZs5KWlibNmjWTr7/+OtPje/bsKdWqVXPlFi1aSMuWLWXbtm0iInL27FnZsWOH9OvXTy5evCinT5+W06dPy5kzZyQyMlISExPl+PHjtxxP+/btxXEcmTJlSrZjHzNmjCxcuFCefPJJ6d27t8ybN09Wr14tiYmJsmTJEg+/E8EjUOf8ypUrIiJSrFixTMduNtHefAy0QJ1zEZHY2FjZuHGjzJs3z7MvOsgF8pwjdwJ1zitWrCj9+vWT1atXy5w5c+THH3+UL7/8Uvr37y9FixYVEf//bQ+Yok5EZPXq1XLvvfdKaGioVKhQQSpVqiR///vf5cKFC5keW6dOnUz31a1b1/Vx8x9++EEcx5HXX39dKlWqpP5NnjxZREROnTrls6/lySeflKpVq8o//vEPn53DBoE45zcvzd+8VJ/R1atX1WOQWSDOeVpamowePVoGDRqk+iiRM4E458ibQJ3z5cuXS9euXSUqKkruvvtuadu2rTRq1EgeffRRERG1woU/BMynX9esWSNDhw6Vnj17yvjx46Vy5cpSuHBhmT59uhw+fNjj17v5vndUVJRERkZm+ZjatWvnaczZueuuu+Ts2bM+PUcgC9Q5L1++vBQrVkxOnDiR6djN++644448n8dGgTrn0dHRcujQIVm+fHmmdcouXrwoSUlJUrlyZSlevHiez2WbQJ1z5F4gz3mZMmVky5YtcvToUUlKSpLw8HAJDw+XiIgIqVSpkpQtW9Yr58mtgCnqNmzYILVq1ZJNmzapT53crMJNiYmJme77/vvvpUaNGiLy+4cWRESKFi0qHTt29P6As+E4jiQlJUmTJk3y/dyBIlDnvFChQtKoUaMsF9+Mi4uTWrVq8Ym5WwjUOT969KjcuHFDHnjggUzHoqOjJTo6WmJiYqRnz54+G0OgCtQ5R+7ZMOfVq1eX6tWri4jI+fPn5d///rf07t07X87tTsC8/XpzoV4nw3IgcXFxt1zgc/Pmzeo99D179khcXJx06dJFREQqV64s7du3l+XLl2d5RSUlJcXteDz52HtWr7V06VJJSUmRzp07Z/v8YBXIc96nTx/Zu3evKuwOHTokO3bskL59+2b7/GAVqHM+YMAAiYmJyfRPRKRr164SExMjLVu2dPsawSpQ5xy5Z9ucT5w4UdLS0mTs2LG5er43Fagrde+//758/PHHme4fM2aMdO/eXTZt2iS9evWSbt26yZEjR2TZsmXSsGFDuXTpUqbn1K5dW9q0aSMjRoyQa9euybx586RChQry0ksvuR6zePFiadOmjTRq1EieffZZqVWrlpw8eVJ2794tycnJcuDAgVuOdc+ePfLQQw/J5MmTs22uDA8Pl/79+0ujRo0kNDRUdu3aJevWrZPGjRvL888/n/NvkIVsnfMXXnhB3n33XenWrZtERUVJ0aJFZe7cuVKlShUZN25czr9BFrJxzuvXry/169fP8ljNmjWD/gqdjXMuInLhwgVZuHChiIh89dVXIiKyaNEiKVu2rJQtWzaot4izdc5nzJgh8fHx0rJlSylSpIhs3rxZPv30U3nzzTcLRj9t/n/gNrObH4G+1b9jx4456enpzrRp05zw8HCnWLFiTpMmTZytW7c6Q4YMccLDw12vdfMj0LNmzXLmzJnj3HXXXU6xYsWcBx980Dlw4ECmcx8+fNgZPHiwU7VqVado0aJOtWrVnO7duzsbNmxwPSavH3sfPny407BhQ6dUqVJO0aJFndq1azsTJkxwUlNT8/JtC2i2z7njOM6xY8ecPn36OKVLl3ZKlizpdO/e3UlMTMzttyzgBcOcm4QlTaye85tjyupfxrEHE9vnfOvWrU6LFi2cUqVKOcWLF3datWrlrF+/Pi/fMq8KcRy2NwAAAAh0AdNTBwAAgFujqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWCDHO0pk3J8NgSMvyxAy54GJOQ8+zHnwYc6DT07mnCt1AAAAFqCoAwAAsABFHQAAgAUo6gAAACxAUQcAAGABijoAAAALUNQBAABYgKIOAADAAhR1AAAAFqCoAwAAsABFHQAAgAUo6gAAACxAUQcAAGABijoAAAALUNQBAABYoIi/BwAUBE2bNlV55MiRKg8ePFjl6OholRcuXKjy119/7cXRAQCQPa7UAQAAWICiDgAAwAIUdQAAABYIcRzHydEDQ0J8PRafKVy4sMplypTx6Plmf1Xx4sVVrlevnsovvviiyrNnz3bdfuKJJ9Sxq1evqjxjxgyV//SnP3k0VlMOpzdLgTzn2WncuLHKO3bsULl06dIevd6FCxdUrlChQq7G5Q3MuX906NDBdXvt2rXqWLt27VQ+dOiQV8/NnPvGa6+9prL597hQof+/LtK+fXt17PPPP/fZuESY82CUkznnSh0AAIAFKOoAAAAsQFEHAABggYBYp6569eoq33bbbSpHRESo3KZNG5XLli2rcu/evb03OBFJTk5WecGCBSr36tXLdfvixYvq2IEDB1T2dR9GsGrRooXKGzduVNnsszR7F8x5u379uspmD12rVq1ct80168zn2qRt27Yqm9+XmJiY/BxOvmrevLnr9t69e/04EuTW0KFDVZ4wYYLK6enpt3xuXnrcAG/hSh0AAIAFKOoAAAAsUCDffs1uuQlPlyTxNvMSvPmx90uXLqmccXmDEydOqGPnzp1T2dtLHQQLc5mZ+++/X+U1a9aofPvtt3v0+omJiSq/9dZbKq9bt07lr776ynXb/PmYPn26R+cOJOayDnXq1FHZprdfMy5nISJSs2ZN1+3w8HB1jCUkAoM5b6GhoX4aCW6lZcuWKg8cOFBlc/mgP/zhD25fLyoqSuWff/5ZZbOdK+N/S+Li4twP1g+4UgcAAGABijoAAAALUNQBAABYoED21B09elTlM2fOqOztnjrzffHz58+r/NBDD6lsLknxl7/8xavjgeeWL1+usrkdW16ZPXolS5ZU2VyKJmNv2b333uvVsRRkgwcPVnn37t1+GonvmX2Zzz77rOu22cOZkJCQL2OCZzp27KjyqFGj3D7enMfu3bu7bp88edJ7A4NL//79VZ4/f77KFStWVNnsX925c6fKlSpVUnnWrFluz2++XsbnDxgwwO1z/YErdQAAABagqAMAALAARR0AAIAFCmRP3dmzZ1UeP368yhn7GEREvvnmG5XNbbpM+/fvV7lTp04q//rrryqb69yMGTPG7evD95o2bapyt27dVM5uXTCzB+7DDz9Uefbs2SqbaxeZP3PmeoMPP/xwjsdiE3PtNputWLHilsfMdQ1RMJhrjq1cuVLl7Pq1zf6rn376yTsDC2JFiugypFmzZiq/++67Kptrkn7xxRcqv/HGGyrv2rVL5WLFiqm8fv16lR955BG34923b5/b4/4WPH+BAQAALEZRBwAAYAGKOgAAAAsUyJ460+bNm1U294K9ePGiyvfdd5/Kw4YNU9nslzJ76Ezffvutys8995zbx8P7zP2At2/frnLp0qVVdhxH5Y8++khlcx07c79Ac79Ws38qJSVF5QMHDqiccX9gs9/PXPPu66+/lkBlrsFXpUoVP40k/7nrvzJ/PlEwDBkyROU77rjD7ePNNc6io6O9PaSgZ+7d6q5XVSTz75a5jl1qaqrb55uPz66HLjk5WeXVq1e7fby/caUOAADAAhR1AAAAFqCoAwAAsEBA9NSZsnvP/MKFC26PZ9yjUUTkb3/7m8oZ+6HgH3Xr1lXZXKvQ7Gc6ffq0yidOnFDZ7IO4dOmSyn//+9/d5rwICwtTedy4cSo/9dRTXjtXfuvatavK5tdqE7NfsGbNmrd87PHjx309HOSAuS/oM888o7L5t97c9/vNN9/0ybiCmbmO3CuvvKKy2Q+9ZMkSlc1+5+zqAdOrr77q0eNHjx6tstlPXdBwpQ4AAMACFHUAAAAWoKgDAACwQED21GVnypQpKpv7hJprknXs2FHlTz/91Cfjwq2Z+/GZawmavVvm2oSDBw9W2dyfryD1elWvXt3fQ/CaevXquT1urvEYyMyfSbPH7vvvv3fdNn8+kT9q1Kih8saNGz16/sKFC1WOjY3N65CC3qRJk1Q2e+iuX7+u8ieffKLyhAkTVL5y5Yrb84WGhqpsrkNn/v019+Y2+yi3bNni9nwFDVfqAAAALEBRBwAAYAGKOgAAAAtY2VNn7uVqrktn7rX57rvvqmz2UZj9WYsXL1bZXFcHnmvSpInKZg+d6bHHHlP5888/9/qYkHd79+719xBuydwvuHPnziqbe1Jmt0dkxvW3zPXOkD/MOTT3JjZ99tlnKs+fP9/rYwpGZcuWdd1+4YUX1DHzv5dmD13Pnj09Olft2rVVXrt2rcpmT71pw4YNKr/11lsenb+g4UodAACABSjqAAAALEBRBwAAYAEre+pMhw8fVnno0KEqr1y5UuVBgwa5zSVKlFA5OjpaZXPfUWRv7ty5KptrB5k9cwW9h65Qof///6Vg3ku4fPnyuX7ufffdp7L5M2GuL3nnnXeqfNttt6ls7rGbcY5EMq9/FRcXp/K1a9dULlJE//n897//LchfZv/VjBkz3D5+165dKg8ZMkTl7PYNR85k/N0z9981mXurVq5cWeWnn35a5R49eqh8zz33qFyyZEmVzR4+M69Zs0Zlsyc/0HClDgAAwAIUdQAAABagqAMAALBAUPTUmWJiYlROTExU2ezv6tChg8rTpk1TOTw8XOWpU6eqfPz48VyN03bdu3d33W7cuLE6ZvY9fPDBB/kxJK/J2Ednfi379+/P59H4jtmHZn6ty5YtU9nc99Edc40xs6cuLS1N5cuXL6v83Xffqfz++++rbK4/afZpnjx5UuXk5GSVzf2EExISBL6V171df/zxR5XNOYZ3ZNzPNSUlRR2rVKmSykeOHFHZ03Vff/75Z5VTU1NVvv3221U+ffq0yh9++KFH5yvouFIHAABgAYo6AAAAC1DUAQAAWCAoe+pM8fHxKvfr10/lRx99VGVzXbvnn39e5Tp16qjcqVOnvA7RShl7ksw1xU6dOqXy3/72t3wZU04VK1ZM5SlTptzysTt27FB54sSJvhiSX5j7Ov70008qR0RE5Pq1jx49qvLmzZtVPnjwoMr/+te/cn2urDz33HMqm71AZn8WfG/ChAkqe7oGZHbr2ME7Mu59bK4luHXrVpXNtSzNdWW3bNmi8qpVq1Q+e/asyuvWrVPZ7Kkzj9uGK3UAAAAWoKgDAACwAEUdAACABeipy0LGfgARkb/85S8qr1ixQmVzD8i2bduq3L59e9ftnTt35nl8wcDcZ9Pf++maPXSvvfaayuPHj1c545pmc+bMUccuXbrk5dEVHDNnzvT3ELzGXJ/S5OkaafCcuX7lI4884tHzzX6sQ4cO5XVI8JC5h7LZm5pX5n9v27Vrp7LZd2l7LyxX6gAAACxAUQcAAGABijoAAAAL0FMnmfeY7NOnj8rNmzdX2eyhM5l7Tn7xxRd5GF1w8vder2Yvj9kz179/f5XN3p3evXv7ZFwoOMw9pOF9n376qcrlypVz+3hzrcKhQ4d6e0goYMw9mM0eOnMvWdapAwAAQIFHUQcAAGABijoAAAALBEVPXb169VQeOXKkyo8//rjKVatW9ej1f/vtN5XNNdU83Z8wWISEhGR5WyTzfoFjxozx6VjGjh2r8uuvv65ymTJlVF67dq3KgwcP9s3AgCBWoUIFlbP7W7pkyRKVbV4TEr/75JNP/D2EAoUrdQAAABagqAMAALAARR0AAIAFrOipM3vgnnjiCZXNHroaNWrk6Xz79u1TeerUqSr7e421QJFx/SBzLSFzThcsWKDy+++/r/KZM2dUbtWqlcqDBg1S+b777lP5zjvvVPno0aMqm30bZu8O7Gf2fdatW1dlc400eG7lypUqFyrk2XWHf/7zn94cDgJAZGSkv4dQoHClDgAAwAIUdQAAABYIiLdfq1SponLDhg1VXrRokcr169fP0/ni4uJUnjVrlsrmllAsWeJ9hQsXVvmFF15Q2dyGKzU1VeU6dep4dD7zbZvY2FiVJ02a5NHrwT5mi4Cnbw0iM3M7vo4dO6ps/m29fv26yosXL1b55MmT3hscAkKtWrX8PYQChb9KAAAAFqCoAwAAsABFHQAAgAUKTE9d+fLlXbeXL1+ujpl9F3l9D93sn5ozZ47K5vIVV65cydP5kLXdu3e7bu/du1cda968udvnmkuemH2XJnPJk3Xr1qns623IYJ/WrVurvGrVKv8MJICVLVtW5ey2aDx+/LjKUVFR3h4SAsyXX36pstnrGmw971ypAwAAsABFHQAAgAUo6gAAACyQbz11LVu2VHn8+PEqt2jRwnW7WrVqeTrX5cuXVTa3mJo2bZrKv/76a57Oh9xJTk523X788cfVseeff17l1157zaPXnj9/vspLly5V+YcffvDo9QBzmzAA/hcfH69yYmKiymYP/t13361ySkqKbwbmJ1ypAwAAsABFHQAAgAUo6gAAACyQbz11vXr1cpvd+e6771TeunWrymlpaSqb686dP38+x+eCf5w4cULlKVOmuM2Ar3300Ucq9+3b108jsVdCQoLK5hqibdq0yc/hwAJmz/yKFStUnjp1qsqjRo1S2aw3Ag1X6gAAACxAUQcAAGABijoAAAALhDiO4+TogazRFJByOL1ZYs4DE3MefJjz4MOcZ6106dIqr1+/XuWOHTuqvGnTJpWffvpplQvSOrY5mXOu1AEAAFiAog4AAMACFHUAAAAWoKfOcvRdBB/mPPgw58GHOc8Zs8fOXKduxIgRKt97770qF6R16+ipAwAACBIUdQAAABagqAMAALAAPXWWo+8i+DDnwYc5Dz7MefChpw4AACBIUNQBAABYgKIOAADAAjnuqQMAAEDBxZU6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWICiDgAAwAIUdQAAABagqAMAALCAdUVdUlKShISEyOzZs732mjt37pSQkBDZuXOn114T3sOcBx/mPPgw58GHOfdcgSjqVq1aJSEhIbJv3z5/D8UnNm3aJP3795datWpJ8eLFpV69ejJu3Dg5f/68v4fmN7bP+aFDh2Ts2LESEREhoaGhEhISIklJSf4ell/ZPucxMTESGRkpd9xxhxQrVkzuvPNO6dOnj8THx/t7aH5j+5zze56Z7XNu6tSpk4SEhMjIkSP9PRQRKSBFne2ee+45OXjwoAwcOFAWLFggnTt3lkWLFknr1q3lypUr/h4efGD37t2yYMECuXjxojRo0MDfw0E++O9//yvlypWTMWPGyJIlS2TEiBHyzTffSIsWLeTAgQP+Hh58gN/z4LZp0ybZvXu3v4ehFPH3AILBhg0bpH379uq+pk2bypAhQ2Tt2rUyfPhw/wwMPtOjRw85f/68lCpVSmbPni379+/395DgY5MmTcp03/Dhw+XOO++UpUuXyrJly/wwKvgSv+fB6+rVqzJu3DiZMGFClr/7/hIwV+quX78ukyZNkqZNm0qZMmWkRIkS8uCDD0psbOwtn/P2229LeHi4hIWFSbt27bJ8GyQhIUH69Okj5cuXl9DQUGnWrJl88MEH2Y7n8uXLkpCQIKdPn872sWZBJyLSq1cvERE5ePBgts8PVoE85+XLl5dSpUpl+zhogTznWalcubIUL148qFstshPIc87vee4E8pzf9NZbb0l6erpERUXl+Dn5IWCKutTUVFmxYoW0b99eZs6cKVOmTJGUlBSJjIzM8v+OoqOjZcGCBfLiiy/KxIkTJT4+Xh5++GE5efKk6zHffvuttGrVSg4ePCgvv/yyzJkzR0qUKCE9e/aUmJgYt+PZs2ePNGjQQBYtWpSrr+eXX34REZGKFSvm6vnBwLY5R/ZsmPPz589LSkqK/Pe//5Xhw4dLamqqdOjQIcfPDzY2zDk8E+hzfvToUZkxY4bMnDlTwsLCPPrafc4pAFauXOmIiLN3795bPiYtLc25du2auu/cuXNOlSpVnGeeecZ135EjRxwRccLCwpzk5GTX/XFxcY6IOGPHjnXd16FDB6dRo0bO1atXXfelp6c7ERERTp06dVz3xcbGOiLixMbGZrpv8uTJufmSnWHDhjmFCxd2vv/++1w9P9AF05zPmjXLERHnyJEjHj3PNsEy5/Xq1XNExBERp2TJks5rr73m/Pbbbzl+vk2CZc4dh9/zm4Jhzvv06eNERES4sog4L774Yo6e62sBc6WucOHCctttt4mISHp6upw9e1bS0tKkWbNm8vXXX2d6fM+ePaVatWqu3KJFC2nZsqVs27ZNRETOnj0rO3bskH79+snFixfl9OnTcvr0aTlz5oxERkZKYmKiHD9+/Jbjad++vTiOI1OmTPH4a/nrX/8q7733nowbN07q1Knj8fODhU1zjpyxYc5XrlwpH3/8sSxZskQaNGggV65ckd9++y3Hzw82Nsw5PBPIcx4bGysbN26UefPmefZF55OA+qDE6tWrZc6cOZKQkCA3btxw3V+zZs1Mj82qWKpbt66sX79eRER++OEHcRxHXn/9dXn99dezPN+pU6fUD5I3fPnllzJs2DCJjIyUqVOnevW1bWTDnMMzgT7nrVu3dt0eMGCA61OR3lxryzaBPufwXCDOeVpamowePVoGDRokzZs3z9Nr+UrAFHVr1qyRoUOHSs+ePWX8+PFSuXJlKVy4sEyfPl0OHz7s8eulp6eLiEhUVJRERkZm+ZjatWvnacymAwcOSI8ePeSee+6RDRs2SJEiAfPt9wsb5hyesW3Oy5UrJw8//LCsXbuWou4WbJtzZC9Q5zw6OloOHToky5cvz7Qe4cWLFyUpKcn14Sh/CZiqYsOGDVKrVi3ZtGmThISEuO6fPHlylo9PTEzMdN/3338vNWrUEBGRWrVqiYhI0aJFpWPHjt4fsOHw4cPSuXNnqVy5smzbtk1Klizp83MGukCfc3jOxjm/cuWKXLhwwS/nDgQ2zjncC9Q5P3r0qNy4cUMeeOCBTMeio6MlOjpaYmJipGfPnj4bQ3YCqqdORMRxHNd9cXFxt1z4b/Pmzeo99D179khcXJx06dJFRH5faqB9+/ayfPlyOXHiRKbnp6SkuB2PJx+B/uWXX+SRRx6RQoUKySeffCKVKlXK9jkI7DlH7gTynJ86dSrTfUlJSfLZZ59Js2bNsn1+sArkOUfuBOqcDxgwQGJiYjL9ExHp2rWrxMTESMuWLd2+hq8VqCt177//vnz88ceZ7h8zZox0795dNm3aJL169ZJu3brJkSNHZNmyZdKwYUO5dOlSpufUrl1b2rRpIyNGjJBr167JvHnzpEKFCvLSSy+5HrN48WJp06aNNGrUSJ599lmpVauWnDx5Unbv3i3JycluV4Hfs2ePPPTQQzJ58uRsmys7d+4sP/74o7z00kuya9cu2bVrl+tYlSpVpFOnTjn47tjJ1jm/cOGCLFy4UEREvvrqKxERWbRokZQtW1bKli1bYLaU8Qdb57xRo0bSoUMHady4sZQrV04SExPlvffekxs3bsiMGTNy/g2ykK1zzu/5rdk45/Xr15f69etneaxmzZp+vULn4odP3GZy8yPQt/p37NgxJz093Zk2bZoTHh7uFCtWzGnSpImzdetWZ8iQIU54eLjrtW5+BHrWrFnOnDlznLvuusspVqyY8+CDDzoHDhzIdO7Dhw87gwcPdqpWreoULVrUqVatmtO9e3dnw4YNrsfk9SPQ7r62du3a5eE7F7hsn/ObY8rqX8axBxPb53zy5MlOs2bNnHLlyjlFihRx7rjjDmfAgAHOf/7zn7x82wKa7XPO73lmts95VqQALWkS4jgZrn8CAAAgIAVMTx0AAABujaIOAADAAhR1AAAAFqCoAwAAsABFHQAAgAUo6gAAACxAUQcAAGCBHO8okXF/NgSOvCxDyJwHJuY8+DDnwYc5Dz45mXOu1AEAAFiAog4AAMACFHUAAAAWoKgDAACwAEUdAACABSjqAAAALEBRBwAAYAGKOgAAAAtQ1AEAAFiAog4AAMACFHUAAAAWoKgDAACwAEUdAACABSjqAAAALEBRBwAAYIEi/h4A4Avz589XefTo0SrHx8er3L17d5V/+ukn3wwMAGClzz77TOWQkBCVH374YZ+PgSt1AAAAFqCoAwAAsABFHQAAgAXoqctCqVKlVC5ZsqTK3bp1U7lSpUoqz507V+Vr1655cXTISo0aNVQeOHCgyunp6So3aNBA5fr166tMT13BV7duXZWLFi2qctu2bVVesmSJyubPRF5t2bLFdXvAgAHq2PXr1716LvzOnPOIiAiVp02bpvIDDzzg8zEheLz99tsqmz9/0dHR+TkcEeFKHQAAgBUo6gAAACxAUQcAAGCBoOypM/uvJkyYoHLr1q1Vvueeezx6/dtvv11lc400eF9KSorKX3zxhco9evTIz+HAC/7whz+oPHToUJX79u2rcqFC+v9R77jjDpXNHjrHcfI4Qi3jz9iyZcvUsT/+8Y8qp6amevXcwapMmTIqx8bGqvzLL7+oXLVqVbfHgezMmDHDdft//ud/1LEbN26obK5blx+4UgcAAGABijoAAAALWPn2q7k8hfnWx1NPPaVyWFiYyubWHseOHVP54sWLKpvLY/Tr10/ljEspJCQk3GLUyItff/1VZZYkCXzTp09XuWvXrn4aiecGDx6s8nvvvafyV199lZ/DCVrm2628/Yq8atWqleu2uaTOrl27VF6/fn2+jCkjrtQBAABYgKIOAADAAhR1AAAAFgjInjrzY+wzZ85UuX///iqb235lJzExUeXIyEiVzffRzT65ihUrus3wvrJly6p83333+Wcg8Jrt27ernF1P3alTp1Q2+9jMJU+y2ybM3PKnXbt2bh+Pgsfsj0bgM7f/e/XVV1V+4oknVD579myezme+XsYlzg4fPqyORUVF5elc3sCVOgAAAAtQ1AEAAFiAog4AAMACAdlT16tXL5WHDx+ep9cz3xfv1KmTyuY6dbVr187T+eB9xYsXV7l69eoePb958+Yqm32SrHuX/5YuXary5s2b3T7e3KInr2uQlS5dWuX4+HiVzW3IMjLHum/fvjyNBbljbgUXGhrqp5HAW9555x2V69Spo3LDhg1VNteO89Qrr7yicoUKFVy3n332WXXswIEDeTqXN3ClDgAAwAIUdQAAABagqAMAALBAQPbU9e3b16PHJyUlqbx3716VJ0yYoLLZQ2cy93qF//38888qr1q1SuUpU6a4fb55/Pz58yovWrQolyNDbqWlpamc3e+lt5nrU5YrVy7Hz01OTlb52rVrXhkT8qZZs2Yq/+tf//LTSJBbly9fVtnbfZONGzdWOTw8XOWM61sWxB5NrtQBAABYgKIOAADAAhR1AAAAFgjInjpzbZjnnntO5U8//VTlH374QWVzj0hPValSJU/Ph++98cYbKmfXUwcMGDBAZfPvTFhYWI5fa9KkSV4ZE9wz+y4vXLigsrlP+N133+3zMcG7zL/ljRo1UvngwYMqe7pWXIkSJVQ2e+zNNVAz9mFu2LDBo3PlB67UAQAAWICiDgAAwAIUdQAAABYIyJ46c02y/O6Xat26db6eD3lXqJD+/5eMaw0hODz11FMqv/zyyyqbezoXLVrUo9ffv3+/67a5Dy18w1xP8ssvv1S5e/fu+TgaeMNdd92lstnbavZRjhw5UuWUlBSPzjd37lyVzXVwzXrjgQce8Oj18xtX6gAAACxAUQcAAGABijoAAAALBGRPXV6NHj1aZXOdmuyY6+SY/vnPf6q8e/duj14f3mf20Jn7BaLgqVGjhsqDBg1SuWPHjh69Xps2bVT29GcgNTVVZbMnb9u2ba7bV65c8ei1gWB1zz33qBwTE6NyxYoVVV64cKHKn3/+uUfni4qKUnno0KFuHz916lSPXt/fuFIHAABgAYo6AAAAC1DUAQAAWMCKnjpzb7aGDRuqPHnyZJW7du3q9vU8XdPMXMfm6aefVvm3335z+3wAmXtrPvjgA5WrV6+en8PJxFwD7Z133vHTSJBbFSpU8PcQgk6RIrrMGDhwoMrvvfeeytn999dcJ3bixIkqm+vOlS9fXmVzHbqQkBCVo6OjVV6+fLkEEq7UAQAAWICiDgAAwAIUdQAAABYIiJ46cw/GJk2aqLxx40aVb7/9dpXNNaPMHjhzHbnOnTurbPbsmcyegccff1zl+fPnu25fv37d7WsB+J3Z62JmT+V1/19zH9EuXbqo/NFHH+VuYMg3PXr08PcQgs6AAQNUXrFihcrmepHm7+UPP/ygcrNmzdzmxx57TOVq1aqpbNYH5l6xzzzzjAQyrtQBAABYgKIOAADAAhR1AAAAFiiQPXW33XabymaP26ZNm9w+/09/+pPKO3bsUPmrr75S2VzHxny8uX6WqVKlSipPnz5d5aNHj7pub968WR27du2a29eGd3jaT9W2bVuVFy1a5PUxQYuPj1e5ffv2KpvrW33yyScqX716NU/nHzZsmMqjRo3K0+sh/8XGxqps9kEif/Tv3991e+XKlerYjRs3VD5//rzKTz75pMrnzp1Tec6cOSq3a9dOZbPHzuzFNXv4zL1ljx07prL5d+jw4cNSkHGlDgAAwAIUdQAAABagqAMAALBAiGO+wXyrB+ZxjajsZFyL7s9//rM6Nn78eLfPNdeHGjRokMrme/ZmD9y2bdtUvv/++1U215Z76623VDZ77sx1cjL6xz/+ofLMmTNVNvsHTPv373d73JTD6c2Sr+c8P5n773r6fbn33ntV/u677/I8Jl9hznOnTJkyKp85c8bt4x999FGV/blOHXP+u969e6v8v//7vyqba5aa+4T/9NNPvhmYDxTkOc/Ylx4eHq6OvfnmmyqbPXfZMefM3JvV3Bs2u54601//+leVBw8e7NH4fCknc86VOgAAAAtQ1AEAAFiAog4AAMACflunrnDhwiq/8cYbrttRUVHq2K+//qryyy+/rPK6detUNnvozHVrzDXHzL1kExMTVR4xYoTK5lpIpUuXVjkiIkLlp556ynXb3Htw+/bt4o65Zk7NmjXdPh5ZW7ZsmcrPP/+8R89/7rnnVP7jH/+Y1yGhgImMjPT3EJBHaWlpbo+b/VXFihXz5XCC1pYtW1y3zXVlzf+mecpcVy67dWSfeOIJlc31ME3Jycm5G1gBwZU6AAAAC1DUAQAAWICiDgAAwAJ+66kze5Qy9tFdvnxZHTP7nz799FOVW7VqpfLTTz+tcpcuXVQOCwtT2VwXz1w3J7segNTUVJU//vjjW2bz/X1znzvT2LFj3R5HziQkJPh7CBC9HuUjjzyijpl7Lptrinmb+Xdi/vz5Pj0ffC9jL5dI5t/7+vXrq2z2xr7wwgs+GVew8ebvkrl+ZN++fVU2e9rNvVnXr1/vtbEEAq7UAQAAWICiDgAAwAIUdQAAABbw296vJ06cUDnjfqzXrl1Tx8y+iBIlSqhcu3Ztj849ZcoUladPn66yuU9oICvI+wP60/fff6/y3Xff7fbxhQrp//8xf+bMPg5/Kkhz3qZNG5VfffVV1+1OnTqpY+YajHldz6p8+fIqd+3aVeWFCxeqXKpUKbevZ/b4mWtOmutX5qeCNOcFybx581Q2+yirVKmi8tWrV309JK8JljmfOHGiyhnXtBURSUlJUbl58+YqB/q6cxmx9ysAAECQoKgDAACwgN+WNPnll19Uzvj2q7l1y3333ef2tbZt26byF198ofLmzZtVTkpKUtmmt1uRM99++63KtWrVcvv49PR0Xw7HWuaWfO629HnppZdUvnjxYp7Obb69e//996uc3VsZO3fuVHnp0qUq+/PtVuSOOefXr1/300hwK+Hh4SoPHz5cZXMO33nnHZVters1N7hSBwAAYAGKOgAAAAtQ1AEAAFjAbz11bdu2Vblnz56u22bvy6lTp1R+//33VT537pzK9EkgO2YfxqOPPuqnkeCmESNG5Ov5zL8rH374ocpjxoxROZCWu0DWzC2lHnvsMZVjYmLyczjIwvbt21U2e+zWrFmj8uTJk30+pkDClToAAAALUNQBAABYgKIOAADAAn7bJgz5I1i2kvGU2aexdetWlRs0aKCy+b2oW7euymwTlrXGjRurPGrUKNftIUOGePVc5hxcvnxZ5S+//FJls68yPj7eq+PJTwVpzguSn3/+WeVy5cqp3KRJE5XNLSkLMlvnPLttwfr27atyMPVBsk0YAABAkKCoAwAAsABFHQAAgAXoqbOcrX0XuLWCPOcZ93UeOnSoOvbmm2+qbPY/mXs4m+tZbdmyRWVzf2mbFeQ596d169apbPbK9ujRQ+WffvrJ52PyFuY8+NBTBwAAECQo6gAAACxAUQcAAGABeuosR99F8GHOgw9zHnyY8+BDTx0AAECQoKgDAACwAEUdAACABSjqAAAALEBRBwAAYAGKOgAAAAtQ1AEAAFiAog4AAMACFHUAAAAWoKgDAACwAEUdAACABXK89ysAAAAKLq7UAQAAWICiDgAAwAIUdQAAABagqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWOD/AN0W3QAP3VaMAAAAAElFTkSuQmCC"
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Get the first 10 samples\n",
"dataiter = iter(trainset)\n",
"images, labels = [], []\n",
"\n",
"for i in range(10):\n",
" image, label = next(dataiter)\n",
" images.append(image)\n",
" labels.append(label)\n",
"\n",
"# Plot the samples\n",
"fig, axes = plt.subplots(2, 5)\n",
"\n",
"for ax, img, lbl in zip(axes.ravel(), images, labels):\n",
" ax.imshow(img.squeeze().numpy(), cmap='gray')\n",
" ax.set_title(f'Label: {lbl}')\n",
" ax.axis('off')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:40:28.606350630Z",
"start_time": "2023-10-03T12:40:28.277928820Z"
}
},
"id": "87c6eae807f51118"
},
{
"cell_type": "markdown",
"source": [
"Define the MLP and Convolutional Autoencoder"
],
"metadata": {
"collapsed": false
},
"id": "e4e25962ef8e5b0d"
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"class Autoencoder(nn.Module):\n",
" def __init__(self, input_size, hidden_size, type='mlp'):\n",
" super(Autoencoder, self).__init__()\n",
" # type of autoencoder\n",
" self.type = type\n",
" if self.type == 'mlp':\n",
" self.encoder = nn.Sequential(\n",
" nn.Linear(input_size, hidden_size),\n",
" nn.ReLU(True))\n",
" self.decoder = nn.Sequential(\n",
" nn.Linear(hidden_size, input_size),\n",
" nn.ReLU(True),\n",
" nn.Sigmoid()\n",
" )\n",
" elif self.type == 'cnn':\n",
" # Encoder module\n",
" self.encoder = nn.Sequential(\n",
" nn.Conv2d(in_channels=input_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(in_channels=hidden_size//2, out_channels=hidden_size, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU()\n",
" )\n",
" # Decoder module\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(in_channels=hidden_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(in_channels=hidden_size//2, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
" nn.Sigmoid() # Sigmoid to ensure the output is between 0 and 1\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.encoder(x)\n",
" x = self.decoder(x)\n",
" return x"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:40:29.561602021Z",
"start_time": "2023-10-03T12:40:29.559204154Z"
}
},
"id": "26f2513d92b78e1e"
},
{
"cell_type": "markdown",
"source": [
"Check if GPU support is available"
],
"metadata": {
"collapsed": false
},
"id": "91a01313b4d95274"
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"# device\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:40:30.834696308Z",
"start_time": "2023-10-03T12:40:30.827970016Z"
}
},
"id": "67006b35b75d8dff"
},
{
"cell_type": "markdown",
"source": [
"Define the training function"
],
"metadata": {
"collapsed": false
},
"id": "8eebf70cb27640d5"
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [],
"source": [
"# Define the training function\n",
"def train(model, train_loader, optimizer, criterion, epoch, verbose=True):\n",
" model.train()\n",
" train_loss = 0\n",
" for i, (data, _) in enumerate(train_loader):\n",
" # check the type of autoencoder and modify the input data accordingly\n",
" if model.type == 'mlp':\n",
" data = data.view(data.size(0), -1)\n",
" data = data.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = criterion(output, data)\n",
" loss.backward()\n",
" train_loss += loss.item()\n",
" optimizer.step()\n",
" train_loss /= len(train_loader.dataset)\n",
" if verbose:\n",
" print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}') \n",
" return train_loss"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:40:31.675127463Z",
"start_time": "2023-10-03T12:40:31.655370005Z"
}
},
"id": "5f96f7be13984747"
},
{
"cell_type": "markdown",
"source": [
"The evaluation functions for the linear classification and clustering tasks"
],
"metadata": {
"collapsed": false
},
"id": "5f6386edcab6b1e4"
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"# Extract encoded representations for a given loader\n",
"def extract_features(loader, model):\n",
" features = []\n",
" labels = []\n",
" model.eval()\n",
" with torch.no_grad():\n",
" for data in loader:\n",
" img, label = data\n",
" if model.type == 'mlp':\n",
" img = img.view(img.size(0), -1)\n",
" img = img.to(device)\n",
" feature = model.encoder(img)\n",
" if model.type == 'cnn':\n",
" feature = feature.view(feature.size(0), -1) # Flatten the CNN encoded features\n",
" features.append(feature)\n",
" labels.append(label)\n",
" return torch.cat(features), torch.cat(labels)\n",
"\n",
"# Define the loss test function\n",
"def test_loss(model, test_loader, criterion):\n",
" model.eval()\n",
" eval_loss = 0\n",
" with torch.no_grad():\n",
" for i, (data, _) in enumerate(test_loader):\n",
" # check the type of autoencoder and modify the input data accordingly\n",
" if model.type == 'mlp':\n",
" data = data.view(data.size(0), -1)\n",
" data = data.to(device)\n",
" output = model(data)\n",
" eval_loss += criterion(output, data).item()\n",
" eval_loss /= len(test_loader.dataset)\n",
" print('====> Test set loss: {:.4f}'.format(eval_loss))\n",
" return eval_loss\n",
"\n",
"# Define the linear classification test function\n",
"def test_linear(encoded_train, train_labels, encoded_test, test_labels):\n",
" train_features_np = encoded_train.cpu().numpy()\n",
" train_labels_np = train_labels.cpu().numpy()\n",
" test_features_np = encoded_test.cpu().numpy()\n",
" test_labels_np = test_labels.cpu().numpy()\n",
" \n",
" # Apply logistic regression on train features and labels\n",
" logistic_regression = LogisticRegression(random_state=0, max_iter=100).fit(train_features_np, train_labels_np)\n",
" print(f\"Train accuracy: {logistic_regression.score(train_features_np, train_labels_np)}\")\n",
" # Apply logistic regression on test features and labels\n",
" test_accuracy = logistic_regression.score(test_features_np, test_labels_np)\n",
" print(f\"Test accuracy: {test_accuracy}\")\n",
" return test_accuracy\n",
"\n",
"\n",
"def test_clustering(encoded_features, true_labels):\n",
" encoded_features_np = encoded_features.cpu().numpy()\n",
" true_labels_np = true_labels.cpu().numpy()\n",
" \n",
" # Apply k-means clustering\n",
" kmeans = KMeans(n_clusters=10, n_init=10, random_state=0).fit(encoded_features_np)\n",
" cluster_labels = kmeans.labels_\n",
" \n",
" # Evaluate clustering results using Adjusted Rand Index\n",
" ari_score = adjusted_rand_score(true_labels_np, cluster_labels)\n",
" print(f\"Clustering ARI score: {ari_score}\")\n",
" return ari_score\n",
"\n",
"def knn_classifier(encoded_train, train_labels, encoded_test, test_labels, k=5):\n",
" encoded_train_np = encoded_train.cpu().numpy()\n",
" encoded_test_np = encoded_test.cpu().numpy()\n",
" train_labels_np = train_labels.cpu().numpy()\n",
" test_labels_np = test_labels.cpu().numpy()\n",
" \n",
" # Apply k-nearest neighbors classification\n",
" knn = KNeighborsClassifier(n_neighbors=k).fit(encoded_train_np, train_labels_np)\n",
" accuracy_score = knn.score(encoded_test_np, test_labels_np)\n",
" print(f\"KNN accuracy: {accuracy_score}\")\n",
" return accuracy_score\n",
"\n",
"def test(model, train_loader, test_loader, criterion):\n",
" # Extract features once for all tests\n",
" encoded_train, train_labels = extract_features(train_loader, model)\n",
" encoded_test, test_labels = extract_features(test_loader, model)\n",
" print(f\"{model.type} Autoencoder\")\n",
" results = {\n",
" 'reconstruction_loss': test_loss(model, test_loader, criterion),\n",
" 'linear_classification_accuracy': test_linear(encoded_train, train_labels, encoded_test, test_labels),\n",
" 'knn_classification_accuracy': knn_classifier(encoded_train, train_labels, encoded_test, test_labels),\n",
" 'clustering_ari_score': test_clustering(encoded_test, test_labels)\n",
" }\n",
" \n",
" # Save results to a log file\n",
" with open(\"evaluation_results.log\", \"w\") as log_file:\n",
" for key, value in results.items():\n",
" log_file.write(f\"{key}: {value}\")\n",
" \n",
" return results\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:40:32.662180572Z",
"start_time": "2023-10-03T12:40:32.657785583Z"
}
},
"id": "b2c4483492fdd427"
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MLP AE parameters: 201616\n",
"CNN AE parameters: 148865\n"
]
}
],
"source": [
"# Define the training parameters for the fully connected MLP Autoencoder\n",
"batch_size = 32\n",
"epochs = 5\n",
"hidden_size = 128\n",
"train_frequency = epochs\n",
"test_frequency = epochs\n",
"\n",
"# Create the fully connected MLP Autoencoder\n",
"input_size = trainset.data.shape[1] * trainset.data.shape[2]\n",
"ae = Autoencoder(input_size, hidden_size, type='mlp').to(device)\n",
"input_size=1\n",
"cnn_ae = Autoencoder(input_size, hidden_size, type='cnn').to(device)\n",
"# print the models' number of parameters\n",
"print(f\"MLP AE parameters: {sum(p.numel() for p in ae.parameters())}\")\n",
"print(f\"CNN AE parameters: {sum(p.numel() for p in cnn_ae.parameters())}\")\n",
"\n",
"# Define the loss function and optimizer\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(ae.parameters(), lr=1e-3)\n",
"optimizer_cnn = optim.Adam(cnn_ae.parameters(), lr=1e-3)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:40:33.593858793Z",
"start_time": "2023-10-03T12:40:33.153759806Z"
}
},
"id": "bcb22bc5af9fb014"
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [],
"source": [
"# Create the train and test dataloaders\n",
"train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
"test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:40:34.157176359Z",
"start_time": "2023-10-03T12:40:34.153720678Z"
}
},
"id": "f1626ce45bb25883"
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 80%|████████ | 4/5 [01:02<00:15, 15.79s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mlp====> Epoch: 5 Average loss: 0.0598\n",
"cnn====> Epoch: 5 Average loss: 0.0260\n",
"mlp Autoencoder\n",
"====> Test set loss: 0.0598\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train accuracy: 0.25626666666666664\n",
"Test accuracy: 0.2649\n",
"KNN accuracy: 0.2295\n",
"Clustering ARI score: 0.0614873771495409\n",
"cnn Autoencoder\n",
"====> Test set loss: 0.0260\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train accuracy: 0.9316166666666666\n",
"Test accuracy: 0.9278\n",
"KNN accuracy: 0.9639\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5/5 [06:56<00:00, 83.22s/it] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Clustering ARI score: 0.3909294873624941\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"test_mlp = []\n",
"test_cnn = []\n",
"# Train the model\n",
"for epoch in tqdm(range(1, epochs + 1)):\n",
" verbose = True if epoch % train_frequency == 0 else False\n",
" train(ae, train_loader, optimizer, criterion, epoch, verbose)\n",
" train(cnn_ae, train_loader, optimizer_cnn, criterion, epoch, verbose)\n",
"\n",
" # test every n epochs\n",
" if epoch % test_frequency == 0:\n",
" restults_dic = test(ae, train_loader, test_loader, criterion)\n",
" test_mlp.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])\n",
" restults_dic = test(cnn_ae, train_loader, test_loader, criterion)\n",
" test_cnn.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:30.814501313Z",
"start_time": "2023-10-03T12:40:34.720164326Z"
}
},
"id": "7472159fdc5f2532"
},
{
"cell_type": "markdown",
"source": [
"Compare the evaluation results of the MLP and CNN Autoencoders"
],
"metadata": {
"collapsed": false
},
"id": "10639256e342a159"
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model Reconstruction Loss Linear Accuracy KNN Accuracy Clustering ARI \n",
"MLP AE 0.0598 0.2649 0.2295 0.0615 \n",
"CNN AE 0.0260 0.9278 0.9639 0.3909 \n"
]
}
],
"source": [
"print(f\"{'Model':<10} {'Reconstruction Loss':<20} {'Linear Accuracy':<20} {'KNN Accuracy':<20} {'Clustering ARI':<20}\")\n",
"print(f\"{'MLP AE':<10} {test_mlp[-1][0]:<20.4f} {test_mlp[-1][1]:<20.4f} {test_mlp[-1][2]:<20.4f} {test_mlp[-1][3]:<20.4f}\")\n",
"print(f\"{'CNN AE':<10} {test_cnn[-1][0]:<20.4f} {test_cnn[-1][1]:<20.4f} {test_cnn[-1][2]:<20.4f} {test_cnn[-1][3]:<20.4f}\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:30.828062767Z",
"start_time": "2023-10-03T12:47:30.812448850Z"
}
},
"id": "50bb4c3c58af09ee"
},
{
"cell_type": "markdown",
"source": [
"Develop a linear classifier with fully connected layers"
],
"metadata": {
"collapsed": false
},
"id": "b9201d1403781706"
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"# Define the fully connected classifier for MNIST\n",
"class DenseClassifier(nn.Module):\n",
" def __init__(self, input_size=784, hidden_size=500, num_classes=10):\n",
" super(DenseClassifier, self).__init__()\n",
" self.type = 'mlp'\n",
" self.encoder = nn.Sequential(\n",
" nn.Linear(input_size, hidden_size),\n",
" nn.ReLU(True))\n",
" self.fc1 = nn.Linear(hidden_size, num_classes)\n",
"\n",
" def forward(self, x):\n",
" x = x.view(x.size(0), -1) # Flatten the input tensor\n",
" x = self.encoder(x)\n",
" x = self.fc1(x)\n",
" return x\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:30.833296890Z",
"start_time": "2023-10-03T12:47:30.819270525Z"
}
},
"id": "1612800950703181"
},
{
"cell_type": "markdown",
"source": [
"Develop a non-linear classifier with convolutional layers"
],
"metadata": {
"collapsed": false
},
"id": "35db4190e9c7f716"
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [],
"source": [
"# cnn classifier\n",
"class CNNClassifier(nn.Module):\n",
" def __init__(self, input_size=3, hidden_size=128, num_classes=10):\n",
" super(CNNClassifier, self).__init__()\n",
" self.type = 'cnn'\n",
" # Encoder (Feature extractor)\n",
" self.encoder = nn.Sequential(\n",
" nn.Conv2d(in_channels=input_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU(),\n",
" nn.Conv2d(in_channels=hidden_size//2, out_channels=hidden_size, kernel_size=3, stride=2, padding=1),\n",
" nn.ReLU()\n",
" )\n",
" \n",
" # Classifier\n",
" # Here, for the sake of example, I'm assuming the spatial size of the encoder output \n",
" # is 7x7 for an input size of 28x28. You might want to adjust this if the spatial dimensions change.\n",
" self.classifier = nn.Sequential(\n",
" nn.Flatten(),\n",
" nn.Linear(hidden_size*7*7, hidden_size),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_size, num_classes),\n",
" nn.LogSoftmax(dim=1) # LogSoftmax is typically used with NLLLoss\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.encoder(x)\n",
" x = self.classifier(x)\n",
" return x\n",
" return x"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:30.916320521Z",
"start_time": "2023-10-03T12:47:30.828281294Z"
}
},
"id": "cb2dfcf75113fd0b"
},
{
"cell_type": "markdown",
"source": [
"Train and test functions for the non-linear classifier"
],
"metadata": {
"collapsed": false
},
"id": "7c3cf2371479da0c"
},
{
"cell_type": "code",
"execution_count": 15,
"outputs": [],
"source": [
"# Train for the classifier\n",
"def train_classifier(model, train_loader, optimizer, criterion, epoch, verbose=True):\n",
" model.train()\n",
" train_loss = 0\n",
" correct = 0 \n",
" for i, (data, target) in enumerate(train_loader):\n",
" if model.type == 'cnn':\n",
" data = data.to(device)\n",
" else:\n",
" data = data.view(data.size(0), -1)\n",
" data = data.to(device)\n",
" target = target.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = criterion(output, target)\n",
" loss.backward()\n",
" train_loss += loss.item()\n",
" optimizer.step()\n",
" # Calculate correct predictions for training accuracy\n",
" pred = output.argmax(dim=1, keepdim=True)\n",
" correct += pred.eq(target.view_as(pred)).sum().item()\n",
"\n",
" train_loss /= len(train_loader.dataset)\n",
" train_accuracy = 100. * correct / len(train_loader.dataset)\n",
" if verbose:\n",
" print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}')\n",
" print(f'{model.type}====> Epoch: {epoch} Training accuracy: {train_accuracy:.2f}%')\n",
" return train_loss\n",
"\n",
"\n",
"def test_classifier(model, test_loader, criterion):\n",
" model.eval()\n",
" eval_loss = 0\n",
" correct = 0\n",
" with torch.no_grad():\n",
" for i, (data, target) in enumerate(test_loader):\n",
" if model.type == 'cnn':\n",
" data = data.to(device)\n",
" else:\n",
" data = data.view(data.size(0), -1)\n",
" data = data.to(device)\n",
" target = target.to(device)\n",
" output = model(data)\n",
" eval_loss += criterion(output, target).item()\n",
" pred = output.argmax(dim=1, keepdim=True)\n",
" correct += pred.eq(target.view_as(pred)).sum().item()\n",
" eval_loss /= len(test_loader.dataset)\n",
" print('====> Test set loss: {:.4f}'.format(eval_loss))\n",
" accuracy = correct / len(test_loader.dataset)\n",
" print('====> Test set accuracy: {:.4f}'.format(accuracy))\n",
" return accuracy\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:30.924063759Z",
"start_time": "2023-10-03T12:47:30.875486950Z"
}
},
"id": "ac980d25bd8a3dd3"
},
{
"cell_type": "code",
"execution_count": 16,
"outputs": [],
"source": [
"# Define the training parameters for the fully connected classifier\n",
"batch_size = 32\n",
"epochs = 5\n",
"learning_rate = 1e-3\n",
"hidden_size = 128\n",
"num_classes = 10\n",
"train_frequency = epochs\n",
"test_frequency = epochs\n",
"# Create the fully connected classifier\n",
"input_size = trainset.data.shape[1] * trainset.data.shape[2]\n",
"classifier = DenseClassifier(input_size, hidden_size, num_classes).to(device)\n",
"input_size = 1\n",
"cnn_classifier = CNNClassifier(input_size, hidden_size, num_classes).to(device)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:30.924544974Z",
"start_time": "2023-10-03T12:47:30.875705556Z"
}
},
"id": "dff05e622dcfd774"
},
{
"cell_type": "code",
"execution_count": 17,
"outputs": [],
"source": [
"# Define the loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)\n",
"optimizer_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:30.924773429Z",
"start_time": "2023-10-03T12:47:30.875787620Z"
}
},
"id": "3104345cdee0eb00"
},
{
"cell_type": "markdown",
"source": [
"Train the non-linear classifiers"
],
"metadata": {
"collapsed": false
},
"id": "e1fed39be2f04745"
},
{
"cell_type": "code",
"execution_count": 18,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 80%|████████ | 4/5 [02:10<00:32, 32.47s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mlp====> Epoch: 5 Average loss: 0.0030\n",
"mlp====> Epoch: 5 Training accuracy: 97.08%\n",
"cnn====> Epoch: 5 Average loss: 0.0005\n",
"cnn====> Epoch: 5 Training accuracy: 99.49%\n",
"====> Test set loss: 0.0035\n",
"====> Test set accuracy: 0.9686\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5/5 [02:48<00:00, 33.62s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"====> Test set loss: 0.0013\n",
"====> Test set accuracy: 0.9880\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Train the classifier\n",
"for epoch in tqdm(range(1, epochs + 1)):\n",
" verbose = True if epoch % train_frequency == 0 else False\n",
" train_classifier(classifier, train_loader, optimizer, criterion, epoch, verbose)\n",
" train_classifier(cnn_classifier, train_loader, optimizer_cnn, criterion, epoch, verbose)\n",
"\n",
" # test every n epochs\n",
" if epoch % test_frequency == 0:\n",
" test_acc = test_classifier(classifier, test_loader, criterion)\n",
" test_acc_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:50:19.005163249Z",
"start_time": "2023-10-03T12:47:30.875867176Z"
}
},
"id": "abc0c6ce338d40d9"
},
{
"cell_type": "markdown",
"source": [
"Load the encoder weights into the classifier"
],
"metadata": {
"collapsed": false
},
"id": "a06038f113d8434f"
},
{
"cell_type": "code",
"execution_count": 19,
"outputs": [
{
"data": {
"text/plain": "<All keys matched successfully>"
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# initialize the classifier with the encoder weights\n",
"classifier.encoder.load_state_dict(ae.encoder.state_dict())\n",
"cnn_classifier.encoder.load_state_dict(cnn_ae.encoder.state_dict())"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:50:19.005816667Z",
"start_time": "2023-10-03T12:50:18.994175691Z"
}
},
"id": "6a91d8894b70ef7c"
},
{
"cell_type": "markdown",
"source": [
"Transfer learning"
],
"metadata": {
"collapsed": false
},
"id": "aafa4a9ba7208647"
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 80%|████████ | 4/5 [01:56<00:27, 27.00s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mlp====> Epoch: 5 Average loss: 0.0547\n",
"mlp====> Epoch: 5 Training accuracy: 38.00%\n",
"cnn====> Epoch: 5 Average loss: 0.0004\n",
"cnn====> Epoch: 5 Training accuracy: 99.57%\n",
"====> Test set loss: 0.0526\n",
"====> Test set accuracy: 0.4150\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5/5 [02:25<00:00, 29.01s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"====> Test set loss: 0.0017\n",
"====> Test set accuracy: 0.9868\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# fine-tune the classifier\n",
"learning_rate = 1e-5\n",
"epochs = 5\n",
"train_frequency = epochs\n",
"test_frequency = epochs\n",
"optimizer_pretrained = optim.Adam(classifier.parameters(), lr=learning_rate)\n",
"optimizer_pretrained_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n",
"for epoch in tqdm(range(1, epochs + 1)):\n",
" verbose = True if epoch % train_frequency == 0 else False\n",
" train_loss = train_classifier(classifier, train_loader, optimizer_pretrained, criterion, epoch, verbose)\n",
" train_loss_cnn = train_classifier(cnn_classifier, train_loader, optimizer_cnn, criterion, epoch, verbose)\n",
"\n",
" # test every n epochs\n",
" if epoch % test_frequency == 0:\n",
" test_acc_pretrained = test_classifier(classifier, test_loader, criterion)\n",
" test_acc_pretrained_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:52:44.085979647Z",
"start_time": "2023-10-03T12:50:19.003728990Z"
}
},
"id": "a60dd68f988a8249"
},
{
"cell_type": "markdown",
"source": [
"Compare the results of the linear probing with the results of the linear classifier"
],
"metadata": {
"collapsed": false
},
"id": "31577275b833707a"
},
{
"cell_type": "code",
"execution_count": 21,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model Linear Accuracy Non-linear accuracy Pretrained accuracy \n",
"MLP AE 0.2649 0.9686 0.4150 \n",
"CNN AE 0.9278 0.9880 0.9868 \n"
]
}
],
"source": [
"# print a table of the accuracies. compare the results with the results of the linear probing\n",
"print(f\"{'Model':<10} {'Linear Accuracy':<20} {'Non-linear accuracy':<20} {'Pretrained accuracy':<20}\")\n",
"print(f\"{'MLP AE':<10} {test_mlp[-1][1]:<20.4f} {test_acc:<20.4f} {test_acc_pretrained:<20.4f}\")\n",
"print(f\"{'CNN AE':<10} {test_cnn[-1][1]:<20.4f} {test_acc_cnn:<20.4f} {test_acc_pretrained_cnn:<20.4f}\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:52:44.091206610Z",
"start_time": "2023-10-03T12:52:44.084572508Z"
}
},
"id": "40d0e7f3f13404c9"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "f38a1ab6951a694e"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}