1136 lines
52 KiB
Plaintext
1136 lines
52 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"| Credentials | |\n",
|
||
|
"|----|----------------------------------|\n",
|
||
|
"|Host | Montanuniversitaet Leoben |\n",
|
||
|
"|Web | https://cps.unileoben.ac.at |\n",
|
||
|
"|Mail | cps@unileoben.ac.at |\n",
|
||
|
"|Author | Fotios Lygerakis |\n",
|
||
|
"|Corresponding Authors | fotios.lygerakis@unileoben.ac.at |\n",
|
||
|
"|Last edited | 28.09.2023 |"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "ae041e151c5c2222"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"In the first part of this tutorial we will build a fully connected MLP Autoencoder on the MNIST dataset. Then we will perform linear probing on the encoder features to see how well they perform on a linear classification task."
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "1490260facaff836"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import torch\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"import torch.optim as optim\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"from torchvision import datasets, transforms\n",
|
||
|
"from sklearn.neighbors import KNeighborsClassifier\n",
|
||
|
"from sklearn.metrics import adjusted_rand_score\n",
|
||
|
"from sklearn.linear_model import LogisticRegression\n",
|
||
|
"from sklearn.cluster import KMeans\n",
|
||
|
"from tqdm import tqdm"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:40:27.590742473Z",
|
||
|
"start_time": "2023-10-03T12:40:25.356175335Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "fc18830bb6f8d534"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Set random seed"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "1bad4bd03deb5b7e"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<torch._C.Generator at 0x7f10889c50d0>"
|
||
|
},
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Set random seed\n",
|
||
|
"torch.manual_seed(0)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:40:27.592356147Z",
|
||
|
"start_time": "2023-10-03T12:40:27.568457489Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "27dd48e60ae7dd9e"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Load the MNIST dataset"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "cc7f167a33227e94"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Define the transformations\n",
|
||
|
"transform = transforms.Compose([transforms.ToTensor(),\n",
|
||
|
" transforms.Normalize((0.5,), (0.5,))])\n",
|
||
|
"# Download and load the training data\n",
|
||
|
"trainset = datasets.MNIST('data', download=True, train=True, transform=transform)\n",
|
||
|
"# Download and load the test data\n",
|
||
|
"testset = datasets.MNIST('data', download=True, train=False, transform=transform)\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:40:27.639417871Z",
|
||
|
"start_time": "2023-10-03T12:40:27.577605311Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "34248e8bc2678fd3"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Print some examples from the dataset"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "928dfac955d0d778"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<Figure size 640x480 with 10 Axes>",
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnUAAAFiCAYAAACQzC7qAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAxUElEQVR4nO3dd3TUVfrH8ScUSegdFCWA9BUF6REBBQxNBKkqTUH9oZTlEEQswK7SpEgvigJZ2MPyAwLKYmElqLhsgFXYjRKMSIQgQqgBqTHf3x8e5pfnJkwyyUwmc+f9Oodz5jPfmfne5Cbx8TvP3BviOI4jAAAACGiF/D0AAAAA5B1FHQAAgAUo6gAAACxAUQcAAGABijoAAAALUNQBAABYgKIOAADAAhR1AAAAFqCoAwAAsIB1RV1SUpKEhITI7NmzvfaaO3fulJCQENm5c6fXXhPew5wHH+Y8+DDnwYc591yBKOpWrVolISEhsm/fPn8PxSemTJkiISEhmf6Fhob6e2h+Y/uci4gcP35c+vXrJ2XLlpXSpUvLY489Jj/++KO/h+U3wTDnGXXq1ElCQkJk5MiR/h6K39g+54cOHZKxY8dKRESEhIaGSkhIiCQlJfl7WH5l+5yLiKxbt07uv/9+CQ0NlUqVKsmwYcPk9OnT/h6WiIgU8fcAgsnSpUulZMmSrly4cGE/jga+dOnSJXnooYfkwoUL8sorr0jRokXl7bfflnbt2sn+/fulQoUK/h4ifGjTpk2ye/dufw8DPrZ7925ZsGCBNGzYUBo0aCD79+/395DgY0uXLpUXXnhBOnToIHPnzpXk5GSZP3++7Nu3T+Li4vx+sYaiLh/16dNHKlas6O9hIB8sWbJEEhMTZc+ePdK8eXMREenSpYvcc889MmfOHJk2bZqfRwhfuXr1qowbN04mTJggkyZN8vdw4EM9evSQ8+fPS6lSpWT27NkUdZa7fv26vPLKK9K2bVvZvn27hISEiIhIRESEPProo/Luu+/KqFGj/DrGAvH2a05cv35dJk2aJE2bNpUyZcpIiRIl5MEHH5TY2NhbPuftt9+W8PBwCQsLk3bt2kl8fHymxyQkJEifPn2kfPnyEhoaKs2aNZMPPvgg2/FcvnxZEhISPLrk6jiOpKamiuM4OX5OMAvkOd+wYYM0b97cVdCJiNSvX186dOgg69evz/b5wSqQ5/ymt956S9LT0yUqKirHzwlmgTzn5cuXl1KlSmX7OGiBOufx8fFy/vx56d+/v6ugExHp3r27lCxZUtatW5ftuXwtYIq61NRUWbFihbRv315mzpwpU6ZMkZSUFImMjMzy/46io6NlwYIF8uKLL8rEiRMlPj5eHn74YTl58qTrMd9++620atVKDh48KC+//LLMmTNHSpQoIT179pSYmBi349mzZ480aNBAFi1alOOvoVatWlKmTBkpVaqUDBw4UI0FmQXqnKenp8t//vMfadasWaZjLVq0kMOHD8vFixdz9k0IMoE65zcdPXpUZsyYITNnzpSwsDCPvvZgFehzDs8F6pxfu3ZNRCTL3+2wsDD55ptvJD09PQffAR9yCoCVK1c6IuLs3bv3lo9JS0tzrl27pu47d+6cU6VKFeeZZ55x3XfkyBFHRJywsDAnOTnZdX9cXJwjIs7YsWNd93Xo0MFp1KiRc/XqVdd96enpTkREhFOnTh3XfbGxsY6IOLGxsZnumzx5crZf37x585yRI0c6a9eudTZs2OCMGTPGKVKkiFOnTh3nwoUL2T7fRjbPeUpKiiMizp///OdMxxYvXuyIiJOQkOD2NWxk85zf1KdPHyciIsKVRcR58cUXc/RcGwXDnN80a9YsR0ScI0eOePQ829g85ykpKU5ISIgzbNgwdX9CQoIjIo6IOKdPn3b7Gr4WMFfqChcuLLfddpuI/H4l5OzZs5KWlibNmjWTr7/+OtPje/bsKdWqVXPlFi1aSMuWLWXbtm0iInL27FnZsWOH9O
vXTy5evCinT5+W06dPy5kzZyQyMlISExPl+PHjtxxP+/btxXEcmTJlSrZjHzNmjCxcuFCefPJJ6d27t8ybN09Wr14tiYmJsmTJEg+/E8EjUOf8ypUrIiJSrFixTMduNtHefAy0QJ1zEZHY2FjZuHGjzJs3z7MvOsgF8pwjdwJ1zitWrCj9+vWT1atXy5w5c+THH3+UL7/8Uvr37y9FixYVEf//bQ+Yok5EZPXq1XLvvfdKaGioVKhQQSpVqiR///vf5cKFC5keW6dOnUz31a1b1/Vx8x9++EEcx5HXX39dKlWqpP5NnjxZREROnTrls6/lySeflKpVq8o//vEPn53DBoE45zcvzd+8VJ/R1atX1WOQWSDOeVpamowePVoGDRqk+iiRM4E458ibQJ3z5cuXS9euXSUqKkruvvtuadu2rTRq1EgeffRRERG1woU/BMynX9esWSNDhw6Vnj17yvjx46Vy5cpSuHBhmT59uhw+fNjj17v5vndUVJRERkZm+ZjatWvnaczZueuuu+Ts2bM+PUcgC9Q5L1++vBQrVkxOnDiR6djN++644448n8dGgTrn0dHRcujQIVm+fHmmdcouXrwoSUlJUrlyZSlevHiez2WbQJ1z5F4gz3mZMmVky5YtcvToUUlKSpLw8HAJDw+XiIgIqVSpkpQtW9Yr58mtgCnqNmzYILVq1ZJNmzapT53crMJNiYmJme77/vvvpUaNGiLy+4cWRESKFi0qHTt29P6As+E4jiQlJUmTJk3y/dyBIlDnvFChQtKoUaMsF9+Mi4uTWrVq8Ym5WwjUOT969KjcuHFDHnjggUzHoqOjJTo6WmJiYqRnz54+G0OgCtQ5R+7ZMOfVq1eX6tWri4jI+fPn5d///rf07t07X87tTsC8/XpzoV4nw3IgcXFxt1zgc/Pmzeo99D179khcXJx06dJFREQqV64s7du3l+XLl2d5RSUlJcXteDz52HtWr7V06VJJSUmRzp07Z/v8YBXIc96nTx/Zu3evKuwOHTokO3bskL59+2b7/GAVqHM+YMAAiYmJyfRPRKRr164SExMjLVu2dPsawSpQ5xy5Z9ucT5w4UdLS0mTs2LG5er43Fagrde+//758/PHHme4fM2aMdO/eXTZt2iS9evWSbt26yZEjR2TZsmXSsGFDuXTpUqbn1K5dW9q0aSMjRoyQa9euybx586RChQry0ksvuR6zePFiadOmjTRq1EieffZZqVWrlpw8eVJ2794tycnJcuDAgVuOdc+ePfLQQw/J5MmTs22uDA8Pl/79+0ujRo0kNDRUdu3aJevWrZPGjRvL888/n/NvkIVsnfMXXnhB3n33XenWrZtERUVJ0aJFZe7cuVKlShUZN25czr9BFrJxzuvXry/169fP8ljNmjWD/gqdjXMuInLhwgVZuHChiIh89dVXIiKyaNEiKVu2rJQtWzaot4izdc5nzJgh8fHx0rJlSylSpIhs3rxZPv30U3nzzTcLRj9t/n/gNrObH4G+1b9jx4456enpzrRp05zw8HCnWLFiTpMmTZytW7c6Q4YMccLDw12vdfMj0LNmzXLmzJnj3HXXXU6xYsWcBx980Dlw4ECmcx8+fNgZPHiwU7VqVado0aJOtWrVnO7duzsbNmxwPSavH3sfPny407BhQ6dUqVJO0aJFndq1azsTJkxwUlNT8/JtC2i2z7njOM6xY8ecPn36OKVLl3ZKlizpdO/e3UlMTMzttyzgBcOcm4QlTaye85tjyupfxrEHE9vnfOvWrU6LFi2cUqVKOcWLF3datWrlrF+/Pi/fMq8KcRy2NwAAAAh0AdNTBwAAgFujqAMAALAARR0AAIAFKOoAAAAsQFEHAABgAYo6AAAAC1DUAQAAWCDHO0pk3J8NgSMvyxAy54GJOQ8+zHnwYc6DT07mnCt1AAAAFqCoAwAAsABFHQAAgAUo6gAAACxAUQcAAGABijoAAAALUNQBAABYgKIOAADAAhR1AAAAFqCoAw
AAsABFHQAAgAUo6gAAACxAUQcAAGABijoAAAALUNQBAABYoIi/BwAUBE2bNlV55MiRKg8ePFjl6OholRcuXKjy119
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Get the first 10 samples\n",
|
||
|
"dataiter = iter(trainset)\n",
|
||
|
"images, labels = [], []\n",
|
||
|
"\n",
|
||
|
"for i in range(10):\n",
|
||
|
" image, label = next(dataiter)\n",
|
||
|
" images.append(image)\n",
|
||
|
" labels.append(label)\n",
|
||
|
"\n",
|
||
|
"# Plot the samples\n",
|
||
|
"fig, axes = plt.subplots(2, 5)\n",
|
||
|
"\n",
|
||
|
"for ax, img, lbl in zip(axes.ravel(), images, labels):\n",
|
||
|
" ax.imshow(img.squeeze().numpy(), cmap='gray')\n",
|
||
|
" ax.set_title(f'Label: {lbl}')\n",
|
||
|
" ax.axis('off')\n",
|
||
|
"\n",
|
||
|
"plt.tight_layout()\n",
|
||
|
"plt.show()"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:40:28.606350630Z",
|
||
|
"start_time": "2023-10-03T12:40:28.277928820Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "87c6eae807f51118"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Define the MLP and Convolutional Autoencoder"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "e4e25962ef8e5b0d"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Autoencoder(nn.Module):\n",
|
||
|
" def __init__(self, input_size, hidden_size, type='mlp'):\n",
|
||
|
" super(Autoencoder, self).__init__()\n",
|
||
|
" # type of autoencoder\n",
|
||
|
" self.type = type\n",
|
||
|
" if self.type == 'mlp':\n",
|
||
|
" self.encoder = nn.Sequential(\n",
|
||
|
" nn.Linear(input_size, hidden_size),\n",
|
||
|
" nn.ReLU(True))\n",
|
||
|
" self.decoder = nn.Sequential(\n",
|
||
|
" nn.Linear(hidden_size, input_size),\n",
|
||
|
" nn.ReLU(True),\n",
|
||
|
" nn.Sigmoid()\n",
|
||
|
" )\n",
|
||
|
" elif self.type == 'cnn':\n",
|
||
|
" # Encoder module\n",
|
||
|
" self.encoder = nn.Sequential(\n",
|
||
|
" nn.Conv2d(in_channels=input_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1),\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" nn.Conv2d(in_channels=hidden_size//2, out_channels=hidden_size, kernel_size=3, stride=2, padding=1),\n",
|
||
|
" nn.ReLU()\n",
|
||
|
" )\n",
|
||
|
" # Decoder module\n",
|
||
|
" self.decoder = nn.Sequential(\n",
|
||
|
" nn.ConvTranspose2d(in_channels=hidden_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" nn.ConvTranspose2d(in_channels=hidden_size//2, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1),\n",
|
||
|
" nn.Sigmoid() # Sigmoid to ensure the output is between 0 and 1\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" def forward(self, x):\n",
|
||
|
" x = self.encoder(x)\n",
|
||
|
" x = self.decoder(x)\n",
|
||
|
" return x"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:40:29.561602021Z",
|
||
|
"start_time": "2023-10-03T12:40:29.559204154Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "26f2513d92b78e1e"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Check if GPU support is available"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "91a01313b4d95274"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"cuda\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# device\n",
|
||
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||
|
"print(device)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:40:30.834696308Z",
|
||
|
"start_time": "2023-10-03T12:40:30.827970016Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "67006b35b75d8dff"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Define the training function"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "8eebf70cb27640d5"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Define the training function\n",
|
||
|
"def train(model, train_loader, optimizer, criterion, epoch, verbose=True):\n",
|
||
|
" model.train()\n",
|
||
|
" train_loss = 0\n",
|
||
|
" for i, (data, _) in enumerate(train_loader):\n",
|
||
|
" # check the type of autoencoder and modify the input data accordingly\n",
|
||
|
" if model.type == 'mlp':\n",
|
||
|
" data = data.view(data.size(0), -1)\n",
|
||
|
" data = data.to(device)\n",
|
||
|
" optimizer.zero_grad()\n",
|
||
|
" output = model(data)\n",
|
||
|
" loss = criterion(output, data)\n",
|
||
|
" loss.backward()\n",
|
||
|
" train_loss += loss.item()\n",
|
||
|
" optimizer.step()\n",
|
||
|
" train_loss /= len(train_loader.dataset)\n",
|
||
|
" if verbose:\n",
|
||
|
" print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}') \n",
|
||
|
" return train_loss"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:40:31.675127463Z",
|
||
|
"start_time": "2023-10-03T12:40:31.655370005Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "5f96f7be13984747"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"The evaluation functions for the linear classification and clustering tasks"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "5f6386edcab6b1e4"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Extract encoded representations for a given loader\n",
|
||
|
"def extract_features(loader, model):\n",
|
||
|
" features = []\n",
|
||
|
" labels = []\n",
|
||
|
" model.eval()\n",
|
||
|
" with torch.no_grad():\n",
|
||
|
" for data in loader:\n",
|
||
|
" img, label = data\n",
|
||
|
" if model.type == 'mlp':\n",
|
||
|
" img = img.view(img.size(0), -1)\n",
|
||
|
" img = img.to(device)\n",
|
||
|
" feature = model.encoder(img)\n",
|
||
|
" if model.type == 'cnn':\n",
|
||
|
" feature = feature.view(feature.size(0), -1) # Flatten the CNN encoded features\n",
|
||
|
" features.append(feature)\n",
|
||
|
" labels.append(label)\n",
|
||
|
" return torch.cat(features), torch.cat(labels)\n",
|
||
|
"\n",
|
||
|
"# Define the loss test function\n",
|
||
|
"def test_loss(model, test_loader, criterion):\n",
|
||
|
" model.eval()\n",
|
||
|
" eval_loss = 0\n",
|
||
|
" with torch.no_grad():\n",
|
||
|
" for i, (data, _) in enumerate(test_loader):\n",
|
||
|
" # check the type of autoencoder and modify the input data accordingly\n",
|
||
|
" if model.type == 'mlp':\n",
|
||
|
" data = data.view(data.size(0), -1)\n",
|
||
|
" data = data.to(device)\n",
|
||
|
" output = model(data)\n",
|
||
|
" eval_loss += criterion(output, data).item()\n",
|
||
|
" eval_loss /= len(test_loader.dataset)\n",
|
||
|
" print('====> Test set loss: {:.4f}'.format(eval_loss))\n",
|
||
|
" return eval_loss\n",
|
||
|
"\n",
|
||
|
"# Define the linear classification test function\n",
|
||
|
"def test_linear(encoded_train, train_labels, encoded_test, test_labels):\n",
|
||
|
" train_features_np = encoded_train.cpu().numpy()\n",
|
||
|
" train_labels_np = train_labels.cpu().numpy()\n",
|
||
|
" test_features_np = encoded_test.cpu().numpy()\n",
|
||
|
" test_labels_np = test_labels.cpu().numpy()\n",
|
||
|
" \n",
|
||
|
" # Apply logistic regression on train features and labels\n",
|
||
|
" logistic_regression = LogisticRegression(random_state=0, max_iter=100).fit(train_features_np, train_labels_np)\n",
|
||
|
" print(f\"Train accuracy: {logistic_regression.score(train_features_np, train_labels_np)}\")\n",
|
||
|
" # Apply logistic regression on test features and labels\n",
|
||
|
" test_accuracy = logistic_regression.score(test_features_np, test_labels_np)\n",
|
||
|
" print(f\"Test accuracy: {test_accuracy}\")\n",
|
||
|
" return test_accuracy\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def test_clustering(encoded_features, true_labels):\n",
|
||
|
" encoded_features_np = encoded_features.cpu().numpy()\n",
|
||
|
" true_labels_np = true_labels.cpu().numpy()\n",
|
||
|
" \n",
|
||
|
" # Apply k-means clustering\n",
|
||
|
" kmeans = KMeans(n_clusters=10, n_init=10, random_state=0).fit(encoded_features_np)\n",
|
||
|
" cluster_labels = kmeans.labels_\n",
|
||
|
" \n",
|
||
|
" # Evaluate clustering results using Adjusted Rand Index\n",
|
||
|
" ari_score = adjusted_rand_score(true_labels_np, cluster_labels)\n",
|
||
|
" print(f\"Clustering ARI score: {ari_score}\")\n",
|
||
|
" return ari_score\n",
|
||
|
"\n",
|
||
|
"def knn_classifier(encoded_train, train_labels, encoded_test, test_labels, k=5):\n",
|
||
|
" encoded_train_np = encoded_train.cpu().numpy()\n",
|
||
|
" encoded_test_np = encoded_test.cpu().numpy()\n",
|
||
|
" train_labels_np = train_labels.cpu().numpy()\n",
|
||
|
" test_labels_np = test_labels.cpu().numpy()\n",
|
||
|
" \n",
|
||
|
" # Apply k-nearest neighbors classification\n",
|
||
|
" knn = KNeighborsClassifier(n_neighbors=k).fit(encoded_train_np, train_labels_np)\n",
|
||
|
" accuracy_score = knn.score(encoded_test_np, test_labels_np)\n",
|
||
|
" print(f\"KNN accuracy: {accuracy_score}\")\n",
|
||
|
" return accuracy_score\n",
|
||
|
"\n",
|
||
|
"def test(model, train_loader, test_loader, criterion):\n",
|
||
|
" # Extract features once for all tests\n",
|
||
|
" encoded_train, train_labels = extract_features(train_loader, model)\n",
|
||
|
" encoded_test, test_labels = extract_features(test_loader, model)\n",
|
||
|
" print(f\"{model.type} Autoencoder\")\n",
|
||
|
" results = {\n",
|
||
|
" 'reconstruction_loss': test_loss(model, test_loader, criterion),\n",
|
||
|
" 'linear_classification_accuracy': test_linear(encoded_train, train_labels, encoded_test, test_labels),\n",
|
||
|
" 'knn_classification_accuracy': knn_classifier(encoded_train, train_labels, encoded_test, test_labels),\n",
|
||
|
" 'clustering_ari_score': test_clustering(encoded_test, test_labels)\n",
|
||
|
" }\n",
|
||
|
" \n",
|
||
|
" # Save results to a log file\n",
|
||
|
" with open(\"evaluation_results.log\", \"w\") as log_file:\n",
|
||
|
" for key, value in results.items():\n",
|
||
|
" log_file.write(f\"{key}: {value}\")\n",
|
||
|
" \n",
|
||
|
" return results\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:40:32.662180572Z",
|
||
|
"start_time": "2023-10-03T12:40:32.657785583Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "b2c4483492fdd427"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"MLP AE parameters: 201616\n",
|
||
|
"CNN AE parameters: 148865\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Define the training parameters for the fully connected MLP Autoencoder\n",
|
||
|
"batch_size = 32\n",
|
||
|
"epochs = 5\n",
|
||
|
"hidden_size = 128\n",
|
||
|
"train_frequency = epochs\n",
|
||
|
"test_frequency = epochs\n",
|
||
|
"\n",
|
||
|
"# Create the fully connected MLP Autoencoder\n",
|
||
|
"input_size = trainset.data.shape[1] * trainset.data.shape[2]\n",
|
||
|
"ae = Autoencoder(input_size, hidden_size, type='mlp').to(device)\n",
|
||
|
"input_size=1\n",
|
||
|
"cnn_ae = Autoencoder(input_size, hidden_size, type='cnn').to(device)\n",
|
||
|
"# print the models' number of parameters\n",
|
||
|
"print(f\"MLP AE parameters: {sum(p.numel() for p in ae.parameters())}\")\n",
|
||
|
"print(f\"CNN AE parameters: {sum(p.numel() for p in cnn_ae.parameters())}\")\n",
|
||
|
"\n",
|
||
|
"# Define the loss function and optimizer\n",
|
||
|
"criterion = nn.MSELoss()\n",
|
||
|
"optimizer = optim.Adam(ae.parameters(), lr=1e-3)\n",
|
||
|
"optimizer_cnn = optim.Adam(cnn_ae.parameters(), lr=1e-3)\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:40:33.593858793Z",
|
||
|
"start_time": "2023-10-03T12:40:33.153759806Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "bcb22bc5af9fb014"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Create the train and test dataloaders\n",
|
||
|
"train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
|
||
|
"test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:40:34.157176359Z",
|
||
|
"start_time": "2023-10-03T12:40:34.153720678Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "f1626ce45bb25883"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
" 80%|████████ | 4/5 [01:02<00:15, 15.79s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"mlp====> Epoch: 5 Average loss: 0.0598\n",
|
||
|
"cnn====> Epoch: 5 Average loss: 0.0260\n",
|
||
|
"mlp Autoencoder\n",
|
||
|
"====> Test set loss: 0.0598\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
||
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
||
|
"\n",
|
||
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
||
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
||
|
"Please also refer to the documentation for alternative solver options:\n",
|
||
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
||
|
" n_iter_i = _check_optimize_result(\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Train accuracy: 0.25626666666666664\n",
|
||
|
"Test accuracy: 0.2649\n",
|
||
|
"KNN accuracy: 0.2295\n",
|
||
|
"Clustering ARI score: 0.0614873771495409\n",
|
||
|
"cnn Autoencoder\n",
|
||
|
"====> Test set loss: 0.0260\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
||
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
||
|
"\n",
|
||
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
||
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
||
|
"Please also refer to the documentation for alternative solver options:\n",
|
||
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
||
|
" n_iter_i = _check_optimize_result(\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Train accuracy: 0.9316166666666666\n",
|
||
|
"Test accuracy: 0.9278\n",
|
||
|
"KNN accuracy: 0.9639\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"100%|██████████| 5/5 [06:56<00:00, 83.22s/it] "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Clustering ARI score: 0.3909294873624941\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"test_mlp = []\n",
|
||
|
"test_cnn = []\n",
|
||
|
"# Train the model\n",
|
||
|
"for epoch in tqdm(range(1, epochs + 1)):\n",
|
||
|
" verbose = True if epoch % train_frequency == 0 else False\n",
|
||
|
" train(ae, train_loader, optimizer, criterion, epoch, verbose)\n",
|
||
|
" train(cnn_ae, train_loader, optimizer_cnn, criterion, epoch, verbose)\n",
|
||
|
"\n",
|
||
|
" # test every n epochs\n",
|
||
|
" if epoch % test_frequency == 0:\n",
|
||
|
" restults_dic = test(ae, train_loader, test_loader, criterion)\n",
|
||
|
" test_mlp.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])\n",
|
||
|
" restults_dic = test(cnn_ae, train_loader, test_loader, criterion)\n",
|
||
|
" test_cnn.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:30.814501313Z",
|
||
|
"start_time": "2023-10-03T12:40:34.720164326Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "7472159fdc5f2532"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Compare the evaluation results of the MLP and CNN Autoencoders"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "10639256e342a159"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Model Reconstruction Loss Linear Accuracy KNN Accuracy Clustering ARI \n",
|
||
|
"MLP AE 0.0598 0.2649 0.2295 0.0615 \n",
|
||
|
"CNN AE 0.0260 0.9278 0.9639 0.3909 \n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"print(f\"{'Model':<10} {'Reconstruction Loss':<20} {'Linear Accuracy':<20} {'KNN Accuracy':<20} {'Clustering ARI':<20}\")\n",
|
||
|
"print(f\"{'MLP AE':<10} {test_mlp[-1][0]:<20.4f} {test_mlp[-1][1]:<20.4f} {test_mlp[-1][2]:<20.4f} {test_mlp[-1][3]:<20.4f}\")\n",
|
||
|
"print(f\"{'CNN AE':<10} {test_cnn[-1][0]:<20.4f} {test_cnn[-1][1]:<20.4f} {test_cnn[-1][2]:<20.4f} {test_cnn[-1][3]:<20.4f}\")"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:30.828062767Z",
|
||
|
"start_time": "2023-10-03T12:47:30.812448850Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "50bb4c3c58af09ee"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Develop a linear classifier with fully connected layers"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "b9201d1403781706"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Define the fully connected classifier for MNIST\n",
|
||
|
"class DenseClassifier(nn.Module):\n",
|
||
|
" def __init__(self, input_size=784, hidden_size=500, num_classes=10):\n",
|
||
|
" super(DenseClassifier, self).__init__()\n",
|
||
|
" self.type = 'mlp'\n",
|
||
|
" self.encoder = nn.Sequential(\n",
|
||
|
" nn.Linear(input_size, hidden_size),\n",
|
||
|
" nn.ReLU(True))\n",
|
||
|
" self.fc1 = nn.Linear(hidden_size, num_classes)\n",
|
||
|
"\n",
|
||
|
" def forward(self, x):\n",
|
||
|
" x = x.view(x.size(0), -1) # Flatten the input tensor\n",
|
||
|
" x = self.encoder(x)\n",
|
||
|
" x = self.fc1(x)\n",
|
||
|
" return x\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:30.833296890Z",
|
||
|
"start_time": "2023-10-03T12:47:30.819270525Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "1612800950703181"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Develop a non-linear classifier with convolutional layers"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "35db4190e9c7f716"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# cnn classifier\n",
|
||
|
"class CNNClassifier(nn.Module):\n",
|
||
|
" def __init__(self, input_size=3, hidden_size=128, num_classes=10):\n",
|
||
|
" super(CNNClassifier, self).__init__()\n",
|
||
|
" self.type = 'cnn'\n",
|
||
|
" # Encoder (Feature extractor)\n",
|
||
|
" self.encoder = nn.Sequential(\n",
|
||
|
" nn.Conv2d(in_channels=input_size, out_channels=hidden_size//2, kernel_size=3, stride=2, padding=1),\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" nn.Conv2d(in_channels=hidden_size//2, out_channels=hidden_size, kernel_size=3, stride=2, padding=1),\n",
|
||
|
" nn.ReLU()\n",
|
||
|
" )\n",
|
||
|
" \n",
|
||
|
" # Classifier\n",
|
||
|
" # Here, for the sake of example, I'm assuming the spatial size of the encoder output \n",
|
||
|
" # is 7x7 for an input size of 28x28. You might want to adjust this if the spatial dimensions change.\n",
|
||
|
" self.classifier = nn.Sequential(\n",
|
||
|
" nn.Flatten(),\n",
|
||
|
" nn.Linear(hidden_size*7*7, hidden_size),\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" nn.Linear(hidden_size, num_classes),\n",
|
||
|
" nn.LogSoftmax(dim=1) # LogSoftmax is typically used with NLLLoss\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" def forward(self, x):\n",
|
||
|
" x = self.encoder(x)\n",
|
||
|
" x = self.classifier(x)\n",
|
||
|
" return x\n",
|
||
|
" return x"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:30.916320521Z",
|
||
|
"start_time": "2023-10-03T12:47:30.828281294Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "cb2dfcf75113fd0b"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Train and test functions for the non-linear classifier"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "7c3cf2371479da0c"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"outputs": [],
"source": [
"# Training / evaluation helpers shared by the MLP and CNN classifiers\n",
"def _prepare_batch(model, data, target):\n",
"    \"\"\"Move a batch to `device`; inputs are flattened per sample unless model.type == 'cnn'.\"\"\"\n",
"    if model.type != 'cnn':\n",
"        data = data.view(data.size(0), -1)\n",
"    return data.to(device), target.to(device)\n",
"\n",
"\n",
"def train_classifier(model, train_loader, optimizer, criterion, epoch, verbose=True):\n",
"    \"\"\"Train `model` for one epoch over `train_loader`.\n",
"\n",
"    Args:\n",
"        model: classifier with a `type` attribute; 'cnn' models get unflattened inputs.\n",
"        train_loader: iterable of (data, target) batches with a `.dataset` attribute.\n",
"        optimizer: optimizer over `model`'s parameters.\n",
"        criterion: loss returning the batch mean (e.g. nn.CrossEntropyLoss()).\n",
"        epoch: epoch number, used only in the log messages.\n",
"        verbose: if True, print the epoch's average loss and training accuracy.\n",
"\n",
"    Returns:\n",
"        Average per-sample training loss over the epoch.\n",
"    \"\"\"\n",
"    model.train()\n",
"    train_loss = 0.0\n",
"    correct = 0\n",
"    for data, target in train_loader:\n",
"        data, target = _prepare_batch(model, data, target)\n",
"        optimizer.zero_grad()\n",
"        output = model(data)\n",
"        loss = criterion(output, target)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        # criterion returns a batch mean; weight it by the batch size so that dividing by\n",
"        # the dataset size below gives a per-sample average (the last batch may be smaller)\n",
"        train_loss += loss.item() * data.size(0)\n",
"        # Calculate correct predictions for training accuracy\n",
"        pred = output.argmax(dim=1, keepdim=True)\n",
"        correct += pred.eq(target.view_as(pred)).sum().item()\n",
"\n",
"    train_loss /= len(train_loader.dataset)\n",
"    train_accuracy = 100. * correct / len(train_loader.dataset)\n",
"    if verbose:\n",
"        print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}')\n",
"        print(f'{model.type}====> Epoch: {epoch} Training accuracy: {train_accuracy:.2f}%')\n",
"    return train_loss\n",
"\n",
"\n",
"def test_classifier(model, test_loader, criterion):\n",
"    \"\"\"Evaluate `model` on `test_loader` without gradient tracking.\n",
"\n",
"    Prints the average per-sample test loss and the accuracy, prefixed with model.type\n",
"    so that results of different models can be told apart.\n",
"\n",
"    Returns:\n",
"        Test accuracy as a fraction in [0, 1].\n",
"    \"\"\"\n",
"    model.eval()\n",
"    eval_loss = 0.0\n",
"    correct = 0\n",
"    with torch.no_grad():\n",
"        for data, target in test_loader:\n",
"            data, target = _prepare_batch(model, data, target)\n",
"            output = model(data)\n",
"            # weight the batch-mean loss by batch size, as in train_classifier\n",
"            eval_loss += criterion(output, target).item() * data.size(0)\n",
"            pred = output.argmax(dim=1, keepdim=True)\n",
"            correct += pred.eq(target.view_as(pred)).sum().item()\n",
"    eval_loss /= len(test_loader.dataset)\n",
"    accuracy = correct / len(test_loader.dataset)\n",
"    print(f'{model.type}====> Test set loss: {eval_loss:.4f}')\n",
"    print(f'{model.type}====> Test set accuracy: {accuracy:.4f}')\n",
"    return accuracy\n"
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:30.924063759Z",
|
||
|
"start_time": "2023-10-03T12:47:30.875486950Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "ac980d25bd8a3dd3"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"outputs": [],
"source": [
"# Define the training parameters for the classifiers trained from scratch\n",
"# NOTE(review): train_loader/test_loader are built in an earlier cell; this cell does not rebuild them with batch_size\n",
"batch_size = 32\n",
"epochs = 5\n",
"learning_rate = 1e-3\n",
"hidden_size = 128\n",
"num_classes = 10\n",
"train_frequency = epochs\n",
"test_frequency = epochs\n",
"# Create the fully connected classifier; its input is the flattened image (height * width)\n",
"input_size = trainset.data.shape[1] * trainset.data.shape[2]\n",
"classifier = DenseClassifier(input_size, hidden_size, num_classes).to(device)\n",
"# Create the CNN classifier; its first argument is presumably the number of input\n",
"# image channels (1 for grayscale MNIST). A separate name keeps input_size intact.\n",
"in_channels = 1\n",
"cnn_classifier = CNNClassifier(in_channels, hidden_size, num_classes).to(device)"
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:30.924544974Z",
|
||
|
"start_time": "2023-10-03T12:47:30.875705556Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "dff05e622dcfd774"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"outputs": [],
"source": [
"# One cross-entropy loss shared by both models, and one Adam optimizer per model\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(params=classifier.parameters(), lr=learning_rate)  # fully connected classifier\n",
"optimizer_cnn = optim.Adam(params=cnn_classifier.parameters(), lr=learning_rate)  # CNN classifier\n"
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:30.924773429Z",
|
||
|
"start_time": "2023-10-03T12:47:30.875787620Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "3104345cdee0eb00"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
"Train the non-linear classifiers (MLP and CNN) from scratch"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "e1fed39be2f04745"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
" 80%|████████ | 4/5 [02:10<00:32, 32.47s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"mlp====> Epoch: 5 Average loss: 0.0030\n",
|
||
|
"mlp====> Epoch: 5 Training accuracy: 97.08%\n",
|
||
|
"cnn====> Epoch: 5 Average loss: 0.0005\n",
|
||
|
"cnn====> Epoch: 5 Training accuracy: 99.49%\n",
|
||
|
"====> Test set loss: 0.0035\n",
|
||
|
"====> Test set accuracy: 0.9686\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"100%|██████████| 5/5 [02:48<00:00, 33.62s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"====> Test set loss: 0.0013\n",
|
||
|
"====> Test set accuracy: 0.9880\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
"source": [
"# Train both classifiers from scratch\n",
"for epoch in tqdm(range(1, epochs + 1)):\n",
"    verbose = epoch % train_frequency == 0\n",
"    train_classifier(classifier, train_loader, optimizer, criterion, epoch, verbose)\n",
"    train_classifier(cnn_classifier, train_loader, optimizer_cnn, criterion, epoch, verbose)\n",
"\n",
"    # test every n epochs, and always after the final epoch so that test_acc / test_acc_cnn\n",
"    # exist for the results table even when test_frequency does not divide epochs\n",
"    if epoch % test_frequency == 0 or epoch == epochs:\n",
"        test_acc = test_classifier(classifier, test_loader, criterion)\n",
"        test_acc_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n"
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:50:19.005163249Z",
|
||
|
"start_time": "2023-10-03T12:47:30.875867176Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "abc0c6ce338d40d9"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
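"Load the pretrained autoencoder encoder weights into both classifiers (only the encoders are overwritten; the remaining layers keep their current weights)"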
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "a06038f113d8434f"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 19,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<All keys matched successfully>"
|
||
|
},
|
||
|
"execution_count": 19,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
"source": [
"# Copy the autoencoders' encoder weights into the classifiers' encoder sub-modules.\n",
"# NOTE(review): only `.encoder` is overwritten; the rest of each classifier keeps the\n",
"# weights learned during the from-scratch training above.\n",
"mlp_encoder_state = ae.encoder.state_dict()\n",
"cnn_encoder_state = cnn_ae.encoder.state_dict()\n",
"classifier.encoder.load_state_dict(mlp_encoder_state)\n",
"cnn_classifier.encoder.load_state_dict(cnn_encoder_state)"
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:50:19.005816667Z",
|
||
|
"start_time": "2023-10-03T12:50:18.994175691Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "6a91d8894b70ef7c"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
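"Transfer learning: fine-tune both classifiers, now initialized with the pretrained encoder weights, using a smaller learning rate"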
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "aafa4a9ba7208647"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
" 80%|████████ | 4/5 [01:56<00:27, 27.00s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"mlp====> Epoch: 5 Average loss: 0.0547\n",
|
||
|
"mlp====> Epoch: 5 Training accuracy: 38.00%\n",
|
||
|
"cnn====> Epoch: 5 Average loss: 0.0004\n",
|
||
|
"cnn====> Epoch: 5 Training accuracy: 99.57%\n",
|
||
|
"====> Test set loss: 0.0526\n",
|
||
|
"====> Test set accuracy: 0.4150\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"100%|██████████| 5/5 [02:25<00:00, 29.01s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"====> Test set loss: 0.0017\n",
|
||
|
"====> Test set accuracy: 0.9868\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
"source": [
"# Fine-tune both classifiers starting from the pretrained encoder weights.\n",
"# Dedicated finetune_* names leave the from-scratch settings (learning_rate, epochs, ...)\n",
"# untouched, so earlier cells behave the same if they are re-run after this one.\n",
"finetune_learning_rate = 1e-5\n",
"finetune_epochs = 5\n",
"finetune_train_frequency = finetune_epochs\n",
"finetune_test_frequency = finetune_epochs\n",
"# fresh optimizers (no Adam state carried over) with the smaller fine-tuning learning rate\n",
"optimizer_pretrained = optim.Adam(classifier.parameters(), lr=finetune_learning_rate)\n",
"optimizer_pretrained_cnn = optim.Adam(cnn_classifier.parameters(), lr=finetune_learning_rate)\n",
"for epoch in tqdm(range(1, finetune_epochs + 1)):\n",
"    verbose = epoch % finetune_train_frequency == 0\n",
"    train_loss = train_classifier(classifier, train_loader, optimizer_pretrained, criterion, epoch, verbose)\n",
"    # each model must be stepped by its own fine-tuning optimizer\n",
"    train_loss_cnn = train_classifier(cnn_classifier, train_loader, optimizer_pretrained_cnn, criterion, epoch, verbose)\n",
"\n",
"    # test every n epochs, and always after the final epoch so the results table has values\n",
"    if epoch % finetune_test_frequency == 0 or epoch == finetune_epochs:\n",
"        test_acc_pretrained = test_classifier(classifier, test_loader, criterion)\n",
"        test_acc_pretrained_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n"
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:52:44.085979647Z",
|
||
|
"start_time": "2023-10-03T12:50:19.003728990Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "a60dd68f988a8249"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
"Compare the test accuracy of linear probing with that of the non-linear classifiers trained from scratch and of the classifiers fine-tuned from the pretrained encoders"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "31577275b833707a"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Model Linear Accuracy Non-linear accuracy Pretrained accuracy \n",
|
||
|
"MLP AE 0.2649 0.9686 0.4150 \n",
|
||
|
"CNN AE 0.9278 0.9880 0.9868 \n"
|
||
|
]
|
||
|
}
|
||
|
],
"source": [
"# Print a table of test accuracies per autoencoder type. The 'Linear Accuracy' column reads the\n",
"# last entry of test_mlp / test_cnn (presumably the linear-probing results from an earlier section),\n",
"# followed by the from-scratch non-linear classifier and the fine-tuned pretrained classifier.\n",
"header_format = '{:<10} {:<20} {:<20} {:<20}'\n",
"row_format = '{:<10} {:<20.4f} {:<20.4f} {:<20.4f}'\n",
"print(header_format.format('Model', 'Linear Accuracy', 'Non-linear accuracy', 'Pretrained accuracy'))\n",
"print(row_format.format('MLP AE', test_mlp[-1][1], test_acc, test_acc_pretrained))\n",
"print(row_format.format('CNN AE', test_cnn[-1][1], test_acc_cnn, test_acc_pretrained_cnn))"
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:52:44.091206610Z",
|
||
|
"start_time": "2023-10-03T12:52:44.084572508Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "40d0e7f3f13404c9"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"outputs": [],
|
||
|
"source": [],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "f38a1ab6951a694e"
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 2
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython2",
|
||
|
"version": "2.7.6"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|