1623 lines
121 KiB
Plaintext
1623 lines
121 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"| Credentials | |\n",
|
||
|
"|----|----------------------------------|\n",
|
||
|
"|Host | Montanuniversitaet Leoben |\n",
|
||
|
"|Web | https://cps.unileoben.ac.at |\n",
|
||
|
"|Mail | cps@unileoben.ac.at |\n",
|
||
|
"|Author | Fotios Lygerakis |\n",
|
||
|
"|Corresponding Authors | fotios.lygerakis@unileoben.ac.at |\n",
|
||
|
"|Last edited | 28.09.2023 |"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "25b3e26209aea07e"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "1490260facaff836",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"In the first part of this tutorial we will build a fully connected MLP Autoencoder on the CIFAR10 dataset. Then we will perform linear probing on the encoder features to see how well they perform on a linear classification task."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "fc18830bb6f8d534",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:41:16.496631884Z",
|
||
|
"start_time": "2023-10-03T12:41:14.795432603Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import torch\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"import torch.optim as optim\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"from torchvision import datasets, transforms\n",
|
||
|
"from sklearn.neighbors import KNeighborsClassifier\n",
|
||
|
"from sklearn.metrics import adjusted_rand_score\n",
|
||
|
"from sklearn.linear_model import LogisticRegression\n",
|
||
|
"from sklearn.cluster import KMeans\n",
|
||
|
"from tqdm import tqdm"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "1bad4bd03deb5b7e",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Set random seed"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "27dd48e60ae7dd9e",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:41:16.504621133Z",
|
||
|
"start_time": "2023-10-03T12:41:16.497289235Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<torch._C.Generator at 0x7f098e1490d0>"
|
||
|
},
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Set random seed\n",
|
||
|
"torch.manual_seed(0)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "cc7f167a33227e94",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Load the CIFAR10 dataset"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "34248e8bc2678fd3",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:41:18.687951160Z",
|
||
|
"start_time": "2023-10-03T12:41:17.491688103Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Files already downloaded and verified\n",
|
||
|
"Files already downloaded and verified\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Define the transformations\n",
|
||
|
"transform = transforms.Compose([transforms.ToTensor(),\n",
|
||
|
" transforms.Normalize((0.5,), (0.5,))])\n",
|
||
|
"# Download and load the training data\n",
|
||
|
"trainset = datasets.CIFAR10('data', download=True, train=True, transform=transform)\n",
|
||
|
"# Download and load the test data\n",
|
||
|
"testset = datasets.CIFAR10('data', download=True, train=False, transform=transform)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "928dfac955d0d778",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Print some examples from the dataset"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "87c6eae807f51118",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:41:19.826292939Z",
|
||
|
"start_time": "2023-10-03T12:41:19.492130184Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<Figure size 1000x400 with 10 Axes>",
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA8cAAAGJCAYAAACnwkFvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAADLAUlEQVR4nOz9eZhdVZ3vj3/2medzaq5KZahKZZ4IJkxhSEAwMimoDNo2RBBoWy8OF1q5rQI/7XZEUbQFtUUUtZVBLo3K1IIiIDOBEDJXhqqk5jp15mGfvb5/+Etdi/cHOIypynm/nqeflnf2sPaa9tpVWa9YxhgjhBBCCCGEEEJIDeM60AUghBBCCCGEEEIONPw4JoQQQgghhBBS8/DjmBBCCCGEEEJIzcOPY0IIIYQQQgghNQ8/jgkhhBBCCCGE1Dz8OCaEEEIIIYQQUvPw45gQQgghhBBCSM3Dj2NCCCGEEEIIITUPP44JIYQQQgghhNQ8/Dj+/7Nz506xLEu++c1vvmnXfPDBB8WyLHnwwQdf9zV+/vOfy4IFC8Tr9UoikXjTykbI3zNZ+/+BZs2aNbJkyZIDXQxyAOCY0OGYIFNpbKxZs0bWrFnzpl6TTH6mUh+tBV5Le1x11VViWdaErKOjQ9atW/cWlQ6Z0h/HP/3pT8WyLHnyyScPdFHeEjZt2iTr1q2Trq4u+dGPfiQ//OEPD3SRyCTiYO//v/zlL+Xaa6890MUgUwiOCUJ0DvaxQaY+7KOvjVwuJ1dddRU/1t8CpvTH8cHOgw8+KI7jyHe+8x1Zt26dnH322Qe6SIS8bfBDgJCJcEwQQggR+dvH8dVXX33QfRx//vOfl3w+f0DLwI/jSczAwICIyKv+dWpjzAHvSIQcSAqFgjiOc6CLQcikgWOCkANHLpc70EUgZEri8XgkEAgc0DIc9B/HpVJJvvjFL8qKFSskHo9LOByWY489Vh544IGXPefb3/62zJo1S4LBoKxevVo2bNgAx2zatEk+8IEPSH19vQQCAVm5cqXceeedr1qeXC4nmzZtkqGhoVc8rqOjQ6688koREWlqahLLsuSqq64a/7PTTjtN7rnnHlm5cqUEg0G54YYbRERkx44dctZZZ0l9fb2EQiE58sgj5Xe/+x1cf9euXfKe97xHwuGwNDc3y6c//Wm55557uJ/iIGOq9v81a9bI7373O9m1a5dYliWWZUlHR4eI/L99P//1X/8ln//856W9vV1CoZCkUil1r4rI//vrWjt37pyQ/+EPf5DVq1dLNBqVWCwmhx12mPzyl798xbLde++9EgqF5IMf/KDYtv2qz0wmFxwTf4NjgryUqTo29vPDH/5Qurq6JBgMyuGHHy4PPfSQelyxWJQrr7xS5syZI36/X2bMmCH/8i//IsViEY69+eabZcWKFRIMBqW+vl7OPfdc2bNnz4Rj9u/Df+qpp+S4446TUCgk/+f//J+qykxeG1O5jz700ENy1llnycyZM8f73ac//Wn45dbL7ZNft27d+Jy/c+dOaWpqEhGRq6++evydsP87QUTkj3/8oxx77LESDoclkUjIe9/7XnnxxRcnXHP/+2HLli3y4Q9/WOLxuDQ1NckXvvAFMcbInj175L3vfa/EYjFpbW2Va665Bso1MDAgF154obS0tEggEJBDDjlEbrrpppeth1drj5d7Z72UZDIpn/rUp2TGjBni9/tlzpw58rWvfe1N+aGw5w1fYZKTSqXkxz/+sXzwgx+Uiy66SNLptPznf/6nrF27Vh5//HFZvnz5hON/9rOfSTqdlo9//ONSKBTkO9/5jpxwwgny/PPPS0tLi4iIvPDCC3L00UdLe3u7fO5zn5NwOCy/+c1v5IwzzpDbbrtNzjzzzJctz+OPPy7HH3+8XHnllRM68Uu59tpr5Wc/+5n89re/lR/84AcSiURk2bJl43++efNm+eAHPyiXXHKJXHTRRTJ//nzp7++XVatWSS6Xk0svvVQaGhrkpptukve85z1y6623jpcrm83KCSecIPv27ZNPfvKT0traKr/85S9fcXIhU5Op2v//9V//VcbGxqSnp0e+/e1vi4hIJBKZcMyXvvQl8fl8ctlll0mxWBSfz/ea6uanP/2pXHDBBbJ48WK54oorJJFIyDPPPCN33323fOhDH1LPueuuu+QDH/iAnHPOOfKTn/xE3G73a7onOfBwTLw8HBO1zVQdGyIi//mf/ymXXHKJrFq1Sj71qU/Jjh075D3veY/U19fLjBkzxo9zHEfe8573yF/+8he5+OKLZeHChfL888/Lt7/9bdmyZYvccccd48f+27/9m3zhC1+Qs88+Wz760Y/K4OCgXHfddXLcccfJM888M+Fv9Q0PD8vJJ58s5557rnz4wx8ef37y5jKV++gtt9wiuVxOPvaxj0lDQ4M8/vjjct1110lPT4/ccsstr6kempqa5Ac/+IF87GMfkzPPPFPe9773iYiMfyfcf//9cvLJJ8vs2bPlqquuknw+L9ddd50cffTR8vTTT49/ZO/nnHPOkYULF8pXv/pV+d3vfidf/vKXpb6+Xm644QY54YQT5Gtf+5r84he/kMsuu0wOO+wwOe6440REJJ/Py5o1a2Tbtm3yiU98Qjo7O+WWW26RdevWSTKZlE9+8pMT7lNNe1RDLpeT1atXS29vr1xyySUyc+ZMeeSRR+SKK66Qffv2vfHtR2YKc+ONNxoRMU888cTLHmPbtikWixOy0dFR09LSYi644ILxrLu724iICQaDpqenZzx/7LHHjIiYT3/60+PZO9/5TrN06VJTKBTGM8dxzKpVq8zcuXPHswceeMCIiHnggQcgu/LKK1/1+a688kojImZwcHBCPmvWLCMi5u67756Qf+pTnzIiYh566KHxLJ1Om87OTtPR0WEqlYoxxphrrrnGiIi54447xo/L5/NmwYIFUF4yeTnY+/+pp55qZs2aBfn+a8yePdvkcrkJf7Z/zLyU/XXV3d1tjDEmmUyaaDRqjjjiCJPP5ycc6zjO+P9evXq1Wbx4sTHGmNtuu814vV5z0UUXjY8lMrngmOCYIDoH89golUqmubnZLF++fEL5f/jDHxoRMatXrx7Pfv7znxuXyzVhnWSMMddff70REfPwww8bY4zZuXOncbvd5t/+7d8mHPf8888bj8czIV+9erUREXP99de/YjnJK3Mw91FjDMzNxhjzla98xViWZXbt2jWerV69ekKf3c/5558/Yf4fHBx82XsvX77cNDc3m+Hh4fFs/fr1xuVymfPOO2882/9+uPjii8cz27bN9OnTjWVZ5qtf/ep4Pjo6aoLBoDn//PPHs2uvvdaIiLn55pvHs1KpZI466igTiURMKpUyxry29tDeWbNmzZpw3y996UsmHA6bLVu2TDjuc5/7nHG73Wb37t1QJ6+Fg/6vVbvd7vGfnjuOIyMjI2LbtqxcuVKefvppOP6MM86Q9vb28f8+/PDD5YgjjpDf//73IiIyMjIif/zjH+Xss8+WdDotQ0NDMjQ0JMPDw7J27VrZunWr9Pb2vmx51qxZI8aYV/0J06vR2dkpa9eunZD9/ve/l8MPP1yOOeaY8SwSicjFF18sO3fulI0bN4qIyN133y3t7e3ynve8Z/y4QCAgF1100RsqE5l8HKz9X0Tk/PPPl2Aw+LrOve+++ySdTsvnPvc52Nui/XWeX/3qV3LOOefIJZdcIjfccIO4XAf91HnQwjGhwzFBpurYePLJJ2VgYED+6Z/+acLflli3bp3E4/EJx95yyy2ycOFCWbBgwXh5hoaG5IQTThARGf8bdLfffrs4jiNnn332hONaW1tl7ty58Dft/H6/fOQjH3nFcpI3zlTtoyIyYW7OZrMyNDQkq1atEmOMPPPMM9VWwauyb98+efbZZ2XdunVSX18/ni9btkxOOumk8Wf/ez760Y+O/2+32y0rV64UY4xceOGF43kikZD58+fLjh07xrPf//730traKh/84AfHM6/XK5deeqlkMhn505/+NOE+r9Ye1XLLLbfIscceK3V1dRPG54knniiVSkX+/Oc/v6brvZSD/q9Vi4jcdNNNcs0118i
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Get the first 10 samples from CIFAR10\n",
|
||
|
"dataiter = iter(trainset)\n",
|
||
|
"images, labels = [], []\n",
|
||
|
"for i in range(10):\n",
|
||
|
" image, label = next(dataiter)\n",
|
||
|
" images.append(image)\n",
|
||
|
" labels.append(label)\n",
|
||
|
"\n",
|
||
|
"# CIFAR10 label names\n",
|
||
|
"cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n",
|
||
|
"\n",
|
||
|
"# Plot the CIFAR10 samples\n",
|
||
|
"fig, axes = plt.subplots(2, 5, figsize=(10, 4))\n",
|
||
|
"for ax, img, lbl in zip(axes.ravel(), images, labels):\n",
|
||
|
" ax.imshow(img.permute(1, 2, 0).numpy() * 0.5 + 0.5) # denormalize\n",
|
||
|
" ax.set_title(f'Label: {cifar10_classes[lbl]}')\n",
|
||
|
" ax.axis('off')\n",
|
||
|
"plt.tight_layout()\n",
|
||
|
"plt.show()\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "e4e25962ef8e5b0d",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Define the MLP and Convolutional Autoencoder"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"id": "26f2513d92b78e1e",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:41:21.511063666Z",
|
||
|
"start_time": "2023-10-03T12:41:21.485979560Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"\n",
|
||
|
"class Autoencoder(nn.Module):\n",
|
||
|
" def __init__(self, input_size, hidden_size, type='mlp'):\n",
|
||
|
" super(Autoencoder, self).__init__()\n",
|
||
|
" self.type = type\n",
|
||
|
" if self.type == 'mlp':\n",
|
||
|
" self.encoder = nn.Sequential(\n",
|
||
|
" nn.Linear(input_size, hidden_size),\n",
|
||
|
" nn.ReLU(True))\n",
|
||
|
" self.decoder = nn.Sequential(\n",
|
||
|
" nn.Linear(hidden_size, input_size),\n",
|
||
|
" nn.ReLU(True),\n",
|
||
|
" nn.Sigmoid()\n",
|
||
|
" )\n",
|
||
|
" elif self.type == 'cnn':\n",
|
||
|
" # Encoder module for CIFAR10\n",
|
||
|
" self.encoder = nn.Sequential(\n",
|
||
|
" nn.Conv2d(input_size, 16, 3, stride=2, padding=1), # 16x16x16\n",
|
||
|
" nn.ReLU(True),\n",
|
||
|
" nn.Conv2d(16, 32, 3, stride=2, padding=1), # 8x8x32\n",
|
||
|
" nn.ReLU(True),\n",
|
||
|
" nn.Conv2d(32, 64, 3, stride=2, padding=1), # 4x4x64\n",
|
||
|
" nn.ReLU(True)\n",
|
||
|
" )\n",
|
||
|
" # Decoder module for CIFAR10\n",
|
||
|
" self.decoder = nn.Sequential(\n",
|
||
|
" nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), # 8x8x32\n",
|
||
|
" nn.ReLU(True),\n",
|
||
|
" nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1), # 16x16x16\n",
|
||
|
" nn.ReLU(True),\n",
|
||
|
" nn.ConvTranspose2d(16, input_size, 3, stride=2, padding=1, output_padding=1), # 32x32x3\n",
|
||
|
" nn.Sigmoid()\n",
|
||
|
" )\n",
|
||
|
" else:\n",
|
||
|
" raise ValueError(f\"Unknown Autoencoder type: {type}\")\n",
|
||
|
" \n",
|
||
|
" def forward(self, x):\n",
|
||
|
" x = self.encoder(x)\n",
|
||
|
" x = self.decoder(x)\n",
|
||
|
" return x\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "91a01313b4d95274",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Check if GPU support is available"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "67006b35b75d8dff",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:41:22.884901160Z",
|
||
|
"start_time": "2023-10-03T12:41:22.860000300Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"cuda\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# device\n",
|
||
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||
|
"print(device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "8eebf70cb27640d5",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Define the training function"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"id": "5f96f7be13984747",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:41:23.954782535Z",
|
||
|
"start_time": "2023-10-03T12:41:23.943574306Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Define the training function\n",
|
||
|
"def train(model, train_loader, optimizer, criterion, epoch, verbose=True):\n",
|
||
|
" model.train()\n",
|
||
|
" train_loss = 0\n",
|
||
|
" for i, (data, _) in enumerate(train_loader):\n",
|
||
|
" # check the type of autoencoder and modify the input data accordingly\n",
|
||
|
" if model.type == 'mlp':\n",
|
||
|
" data = data.view(data.size(0), -1)\n",
|
||
|
" data = data.to(device)\n",
|
||
|
" optimizer.zero_grad()\n",
|
||
|
" output = model(data)\n",
|
||
|
" loss = criterion(output, data)\n",
|
||
|
" loss.backward()\n",
|
||
|
" train_loss += loss.item()\n",
|
||
|
" optimizer.step()\n",
|
||
|
" train_loss /= len(train_loader.dataset)\n",
|
||
|
" if verbose:\n",
|
||
|
" print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}') \n",
|
||
|
" return train_loss"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "5f6386edcab6b1e4",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"The evaluation functions for the linear classification"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"id": "b2c4483492fdd427",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:41:24.972220830Z",
|
||
|
"start_time": "2023-10-03T12:41:24.960696742Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Extract encoded representations for a given loader\n",
|
||
|
"def extract_features(loader, model):\n",
|
||
|
" features = []\n",
|
||
|
" labels = []\n",
|
||
|
" model.eval()\n",
|
||
|
" with torch.no_grad():\n",
|
||
|
" for data in loader:\n",
|
||
|
" img, label = data\n",
|
||
|
" if model.type == 'mlp':\n",
|
||
|
" img = img.view(img.size(0), -1)\n",
|
||
|
" img = img.to(device)\n",
|
||
|
" feature = model.encoder(img)\n",
|
||
|
" if model.type == 'cnn':\n",
|
||
|
" feature = feature.view(feature.size(0), -1) # Flatten the CNN encoded features\n",
|
||
|
" features.append(feature)\n",
|
||
|
" labels.append(label)\n",
|
||
|
" return torch.cat(features), torch.cat(labels)\n",
|
||
|
"\n",
|
||
|
"# Define the loss test function\n",
|
||
|
"def test_loss(model, test_loader, criterion, verbose=True):\n",
|
||
|
" model.eval()\n",
|
||
|
" eval_loss = 0\n",
|
||
|
" with torch.no_grad():\n",
|
||
|
" for i, (data, _) in enumerate(test_loader):\n",
|
||
|
" # check the type of autoencoder and modify the input data accordingly\n",
|
||
|
" if model.type == 'mlp':\n",
|
||
|
" data = data.view(data.size(0), -1)\n",
|
||
|
" data = data.to(device)\n",
|
||
|
" output = model(data)\n",
|
||
|
" eval_loss += criterion(output, data).item()\n",
|
||
|
" eval_loss /= len(test_loader.dataset)\n",
|
||
|
" if verbose:\n",
|
||
|
" print('====> Test set loss: {:.4f}'.format(eval_loss))\n",
|
||
|
" return eval_loss\n",
|
||
|
"\n",
|
||
|
"# Define the linear classification test function\n",
|
||
|
"def test_linear(encoded_train, train_labels, encoded_test, test_labels):\n",
|
||
|
" train_features_np = encoded_train.cpu().numpy()\n",
|
||
|
" train_labels_np = train_labels.cpu().numpy()\n",
|
||
|
" test_features_np = encoded_test.cpu().numpy()\n",
|
||
|
" test_labels_np = test_labels.cpu().numpy()\n",
|
||
|
" \n",
|
||
|
" # Apply logistic regression on train features and labels\n",
|
||
|
" logistic_regression = LogisticRegression(random_state=0, max_iter=100).fit(train_features_np, train_labels_np)\n",
|
||
|
" print(f\"Train accuracy: {logistic_regression.score(train_features_np, train_labels_np)}\")\n",
|
||
|
" # Apply logistic regression on test features and labels\n",
|
||
|
" test_accuracy = logistic_regression.score(test_features_np, test_labels_np)\n",
|
||
|
" print(f\"Test accuracy: {test_accuracy}\")\n",
|
||
|
" return test_accuracy\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def test_clustering(encoded_features, true_labels):\n",
|
||
|
" encoded_features_np = encoded_features.cpu().numpy()\n",
|
||
|
" true_labels_np = true_labels.cpu().numpy()\n",
|
||
|
" \n",
|
||
|
" # Apply k-means clustering\n",
|
||
|
" kmeans = KMeans(n_clusters=10, n_init=10, random_state=0).fit(encoded_features_np)\n",
|
||
|
" cluster_labels = kmeans.labels_\n",
|
||
|
" \n",
|
||
|
" # Evaluate clustering results using Adjusted Rand Index\n",
|
||
|
" ari_score = adjusted_rand_score(true_labels_np, cluster_labels)\n",
|
||
|
" print(f\"Clustering ARI score: {ari_score}\")\n",
|
||
|
" return ari_score\n",
|
||
|
"\n",
|
||
|
"def knn_classifier(encoded_train, train_labels, encoded_test, test_labels, k=5):\n",
|
||
|
" encoded_train_np = encoded_train.cpu().numpy()\n",
|
||
|
" encoded_test_np = encoded_test.cpu().numpy()\n",
|
||
|
" train_labels_np = train_labels.cpu().numpy()\n",
|
||
|
" test_labels_np = test_labels.cpu().numpy()\n",
|
||
|
" \n",
|
||
|
" # Apply k-nearest neighbors classification\n",
|
||
|
" knn = KNeighborsClassifier(n_neighbors=k).fit(encoded_train_np, train_labels_np)\n",
|
||
|
" accuracy_score = knn.score(encoded_test_np, test_labels_np)\n",
|
||
|
" print(f\"KNN accuracy: {accuracy_score}\")\n",
|
||
|
" return accuracy_score\n",
|
||
|
"\n",
|
||
|
"def test(model, train_loader, test_loader, criterion):\n",
|
||
|
" # Extract features once for all tests\n",
|
||
|
" encoded_train, train_labels = extract_features(train_loader, model)\n",
|
||
|
" encoded_test, test_labels = extract_features(test_loader, model)\n",
|
||
|
" print(f\"{model.type} Autoencoder\")\n",
|
||
|
" results = {\n",
|
||
|
" 'reconstruction_loss': test_loss(model, test_loader, criterion),\n",
|
||
|
" 'linear_classification_accuracy': test_linear(encoded_train, train_labels, encoded_test, test_labels),\n",
|
||
|
" 'knn_classification_accuracy': knn_classifier(encoded_train, train_labels, encoded_test, test_labels),\n",
|
||
|
" 'clustering_ari_score': test_clustering(encoded_test, test_labels)\n",
|
||
|
" }\n",
|
||
|
" \n",
|
||
|
" # Save results to a log file\n",
|
||
|
" with open(\"evaluation_results.log\", \"w\") as log_file:\n",
|
||
|
" for key, value in results.items():\n",
|
||
|
" log_file.write(f\"{key}: {value}\")\n",
|
||
|
" \n",
|
||
|
" return results\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"id": "bcb22bc5af9fb014",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:41:26.789225198Z",
|
||
|
"start_time": "2023-10-03T12:41:26.359259721Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"MLP AE parameters: 789632\n",
|
||
|
"CNN AE parameters: 47107\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Define the training parameters for the fully connected MLP Autoencoder\n",
|
||
|
"batch_size = 32\n",
|
||
|
"epochs = 5\n",
|
||
|
"input_size = trainset.data.shape[1] * trainset.data.shape[2] * trainset.data.shape[3]\n",
|
||
|
"hidden_size = 128\n",
|
||
|
"train_frequency = epochs\n",
|
||
|
"test_frequency = epochs\n",
|
||
|
"\n",
|
||
|
"# Create the fully connected MLP Autoencoder\n",
|
||
|
"ae = Autoencoder(input_size, hidden_size, type='mlp').to(device)\n",
|
||
|
"input_size = trainset.data.shape[3]\n",
|
||
|
"cnn_ae = Autoencoder(input_size, hidden_size, type='cnn').to(device)\n",
|
||
|
"# print the models' number of parameters\n",
|
||
|
"print(f\"MLP AE parameters: {sum(p.numel() for p in ae.parameters())}\")\n",
|
||
|
"print(f\"CNN AE parameters: {sum(p.numel() for p in cnn_ae.parameters())}\")\n",
|
||
|
"\n",
|
||
|
"# Define the loss function and optimizer\n",
|
||
|
"criterion = nn.MSELoss()\n",
|
||
|
"optimizer = optim.Adam(ae.parameters(), lr=1e-3)\n",
|
||
|
"optimizer_cnn = optim.Adam(cnn_ae.parameters(), lr=1e-3)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "f1626ce45bb25883",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:41:27.399274149Z",
|
||
|
"start_time": "2023-10-03T12:41:27.378668195Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Create the train and test dataloaders\n",
|
||
|
"train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
|
||
|
"test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"id": "7472159fdc5f2532",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:02.996829577Z",
|
||
|
"start_time": "2023-10-03T12:41:27.969843962Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
" 80%|████████ | 4/5 [01:21<00:22, 22.85s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"mlp====> Epoch: 5 Average loss: 0.0174\n",
|
||
|
"cnn====> Epoch: 5 Average loss: 0.0045\n",
|
||
|
"mlp Autoencoder\n",
|
||
|
"====> Test set loss: 0.0172\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
||
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
||
|
"\n",
|
||
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
||
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
||
|
"Please also refer to the documentation for alternative solver options:\n",
|
||
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
||
|
" n_iter_i = _check_optimize_result(\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Train accuracy: 0.39372\n",
|
||
|
"Test accuracy: 0.3933\n",
|
||
|
"KNN accuracy: 0.3391\n",
|
||
|
"Clustering ARI score: 0.03965075119494781\n",
|
||
|
"cnn Autoencoder\n",
|
||
|
"====> Test set loss: 0.0044\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
||
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
||
|
"\n",
|
||
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
||
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
||
|
"Please also refer to the documentation for alternative solver options:\n",
|
||
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
||
|
" n_iter_i = _check_optimize_result(\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Train accuracy: 0.41372\n",
|
||
|
"Test accuracy: 0.3981\n",
|
||
|
"KNN accuracy: 0.3545\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"100%|██████████| 5/5 [05:34<00:00, 67.00s/it] "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Clustering ARI score: 0.051552759482472406\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"test_mlp = []\n",
|
||
|
"test_cnn = []\n",
|
||
|
"# Train the model\n",
|
||
|
"for epoch in tqdm(range(1, epochs + 1)):\n",
|
||
|
" verbose = True if epoch % train_frequency == 0 else False\n",
|
||
|
" train(ae, train_loader, optimizer, criterion, epoch, verbose)\n",
|
||
|
" train(cnn_ae, train_loader, optimizer_cnn, criterion, epoch, verbose)\n",
|
||
|
"\n",
|
||
|
" # test every n epochs\n",
|
||
|
" if epoch % test_frequency == 0:\n",
|
||
|
" restults_dic = test(ae, train_loader, test_loader, criterion)\n",
|
||
|
" test_mlp.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])\n",
|
||
|
" restults_dic = test(cnn_ae, train_loader, test_loader, criterion)\n",
|
||
|
" test_cnn.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "10639256e342a159",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Compare the evaluation results of the MLP and CNN Autoencoders"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 30,
|
||
|
"id": "50bb4c3c58af09ee",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T13:06:01.418631529Z",
|
||
|
"start_time": "2023-10-03T13:06:01.405317147Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Model Reconstruction Loss Linear Accuracy KNN Accuracy Clustering ARI \n",
|
||
|
"MLP AE 0.0172 0.3933 0.3391 0.0397 \n",
|
||
|
"CNN AE 0.0044 0.3981 0.3545 0.0516 \n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"print(f\"{'Model':<10} {'Reconstruction Loss':<20} {'Linear Accuracy':<20} {'KNN Accuracy':<20} {'Clustering ARI':<20}\")\n",
|
||
|
"print(f\"{'MLP AE':<10} {test_mlp[-1][0]:<20.4f} {test_mlp[-1][1]:<20.4f} {test_mlp[-1][2]:<20.4f} {test_mlp[-1][3]:<20.4f}\")\n",
|
||
|
"print(f\"{'CNN AE':<10} {test_cnn[-1][0]:<20.4f} {test_cnn[-1][1]:<20.4f} {test_cnn[-1][2]:<20.4f} {test_cnn[-1][3]:<20.4f}\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "b9201d1403781706",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Develop a linear classifier with fully connected layers"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"id": "1612800950703181",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:03.004530043Z",
|
||
|
"start_time": "2023-10-03T12:47:02.996038048Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Define the fully connected classifier for MNIST\n",
|
||
|
"class DenseClassifier(nn.Module):\n",
|
||
|
" def __init__(self, input_size=784, hidden_size=500, num_classes=10):\n",
|
||
|
" super(DenseClassifier, self).__init__()\n",
|
||
|
" self.type = 'mlp'\n",
|
||
|
" self.encoder = nn.Sequential(\n",
|
||
|
" nn.Linear(input_size, hidden_size),\n",
|
||
|
" nn.ReLU(True))\n",
|
||
|
" self.fc1 = nn.Linear(hidden_size, num_classes)\n",
|
||
|
"\n",
|
||
|
" def forward(self, x):\n",
|
||
|
" x = x.view(x.size(0), -1) # Flatten the input tensor\n",
|
||
|
" x = self.encoder(x)\n",
|
||
|
" x = self.fc1(x)\n",
|
||
|
" return x\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "35db4190e9c7f716",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Develop a non-linear classifier with convolutional layers"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"id": "cb2dfcf75113fd0b",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:03.041289873Z",
|
||
|
"start_time": "2023-10-03T12:47:02.999128924Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# cnn classifier\n",
|
||
|
"class CNNClassifier(nn.Module):\n",
|
||
|
" def __init__(self, input_size, hidden_size=128, num_classes=10):\n",
|
||
|
" super(CNNClassifier, self).__init__()\n",
|
||
|
" self.type = 'cnn'\n",
|
||
|
" # Encoder (Feature extractor)\n",
|
||
|
" self.encoder = nn.Sequential(\n",
|
||
|
" nn.Conv2d(input_size, 16, 3, stride=2, padding=1), # 16x16x16\n",
|
||
|
" nn.ReLU(True),\n",
|
||
|
" nn.Conv2d(16, 32, 3, stride=2, padding=1), # 8x8x32\n",
|
||
|
" nn.ReLU(True),\n",
|
||
|
" nn.Conv2d(32, 64, 3, stride=2, padding=1), # 4x4x64\n",
|
||
|
" nn.ReLU(True)\n",
|
||
|
" )\n",
|
||
|
" \n",
|
||
|
" # Classifier\n",
|
||
|
" self.classifier = nn.Sequential(\n",
|
||
|
" nn.Flatten(), \n",
|
||
|
" nn.Linear(4*4*64, hidden_size),\n",
|
||
|
" nn.ReLU(True),\n",
|
||
|
" nn.Linear(hidden_size, num_classes)\n",
|
||
|
" )\n",
|
||
|
"\n",
|
||
|
" def forward(self, x):\n",
|
||
|
" x = self.encoder(x)\n",
|
||
|
" x = self.classifier(x)\n",
|
||
|
" return x\n",
|
||
|
" return x"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "7c3cf2371479da0c",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Train and test functions for the non-linear classifier"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"id": "ac980d25bd8a3dd3",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:03.062646362Z",
|
||
|
"start_time": "2023-10-03T12:47:03.020366991Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Train for the classifier\n",
|
||
|
"def train_classifier(model, train_loader, optimizer, criterion, epoch, verbose=True):\n",
|
||
|
" model.train()\n",
|
||
|
" train_loss = 0\n",
|
||
|
" correct = 0 \n",
|
||
|
" for i, (data, target) in enumerate(train_loader):\n",
|
||
|
" if model.type == 'cnn':\n",
|
||
|
" data = data.to(device)\n",
|
||
|
" else:\n",
|
||
|
" data = data.view(data.size(0), -1)\n",
|
||
|
" data = data.to(device)\n",
|
||
|
" target = target.to(device)\n",
|
||
|
" optimizer.zero_grad()\n",
|
||
|
" output = model(data)\n",
|
||
|
" loss = criterion(output, target)\n",
|
||
|
" loss.backward()\n",
|
||
|
" train_loss += loss.item()\n",
|
||
|
" optimizer.step()\n",
|
||
|
" # Calculate correct predictions for training accuracy\n",
|
||
|
" pred = output.argmax(dim=1, keepdim=True)\n",
|
||
|
" correct += pred.eq(target.view_as(pred)).sum().item()\n",
|
||
|
"\n",
|
||
|
" train_loss /= len(train_loader.dataset)\n",
|
||
|
" train_accuracy = 100. * correct / len(train_loader.dataset)\n",
|
||
|
" if verbose:\n",
|
||
|
" print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}')\n",
|
||
|
" print(f'{model.type}====> Epoch: {epoch} Training accuracy: {train_accuracy:.2f}%')\n",
|
||
|
" return train_loss\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def test_classifier(model, test_loader, criterion):\n",
|
||
|
" model.eval()\n",
|
||
|
" eval_loss = 0\n",
|
||
|
" correct = 0\n",
|
||
|
" with torch.no_grad():\n",
|
||
|
" for i, (data, target) in enumerate(test_loader):\n",
|
||
|
" if model.type == 'cnn':\n",
|
||
|
" data = data.to(device)\n",
|
||
|
" else:\n",
|
||
|
" data = data.view(data.size(0), -1)\n",
|
||
|
" data = data.to(device)\n",
|
||
|
" target = target.to(device)\n",
|
||
|
" output = model(data)\n",
|
||
|
" eval_loss += criterion(output, target).item()\n",
|
||
|
" pred = output.argmax(dim=1, keepdim=True)\n",
|
||
|
" correct += pred.eq(target.view_as(pred)).sum().item()\n",
|
||
|
" eval_loss /= len(test_loader.dataset)\n",
|
||
|
" print('====> Test set loss: {:.4f}'.format(eval_loss))\n",
|
||
|
" accuracy = correct / len(test_loader.dataset)\n",
|
||
|
" print('====> Test set accuracy: {:.4f}'.format(accuracy))\n",
|
||
|
" return accuracy\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"id": "dff05e622dcfd774",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:03.092600803Z",
|
||
|
"start_time": "2023-10-03T12:47:03.049595871Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Define the training parameters for the fully connected classifier\n",
|
||
|
"batch_size = 32\n",
|
||
|
"epochs = 5\n",
|
||
|
"learning_rate = 1e-3\n",
|
||
|
"hidden_size = 128\n",
|
||
|
"num_classes = 10\n",
|
||
|
"train_frequency = epochs\n",
|
||
|
"test_frequency = epochs\n",
|
||
|
"# Create the fully connected classifier\n",
|
||
|
"input_size = trainset.data.shape[1] * trainset.data.shape[2] * trainset.data.shape[3]\n",
|
||
|
"classifier = DenseClassifier(input_size, hidden_size, num_classes).to(device)\n",
|
||
|
"input_size = trainset.data.shape[3]\n",
|
||
|
"cnn_classifier = CNNClassifier(input_size, hidden_size, num_classes).to(device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"id": "3104345cdee0eb00",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:47:03.095971390Z",
|
||
|
"start_time": "2023-10-03T12:47:03.080480468Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Define the loss function and optimizer\n",
|
||
|
"criterion = nn.CrossEntropyLoss()\n",
|
||
|
"optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)\n",
|
||
|
"optimizer_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "e1fed39be2f04745",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Train the non-linear classifiers"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"id": "abc0c6ce338d40d9",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:49:41.536281525Z",
|
||
|
"start_time": "2023-10-03T12:47:03.094970008Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
" 80%|████████ | 4/5 [02:03<00:30, 30.65s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"mlp====> Epoch: 5 Average loss: 0.0401\n",
|
||
|
"mlp====> Epoch: 5 Training accuracy: 55.44%\n",
|
||
|
"cnn====> Epoch: 5 Average loss: 0.0265\n",
|
||
|
"cnn====> Epoch: 5 Training accuracy: 69.87%\n",
|
||
|
"====> Test set loss: 0.0446\n",
|
||
|
"====> Test set accuracy: 0.5129\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"100%|██████████| 5/5 [02:38<00:00, 31.69s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"====> Test set loss: 0.0319\n",
|
||
|
"====> Test set accuracy: 0.6478\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Train the classifier\n",
|
||
|
"for epoch in tqdm(range(1, epochs + 1)):\n",
|
||
|
" verbose = True if epoch % train_frequency == 0 else False\n",
|
||
|
" train_classifier(classifier, train_loader, optimizer, criterion, epoch, verbose)\n",
|
||
|
" train_classifier(cnn_classifier, train_loader, optimizer_cnn, criterion, epoch, verbose)\n",
|
||
|
"\n",
|
||
|
" # test every n epochs\n",
|
||
|
" if epoch % test_frequency == 0:\n",
|
||
|
" test_acc = test_classifier(classifier, test_loader, criterion)\n",
|
||
|
" test_acc_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "a06038f113d8434f",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Load the encoder weights into the classifier"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 19,
|
||
|
"id": "6a91d8894b70ef7c",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:49:41.571816419Z",
|
||
|
"start_time": "2023-10-03T12:49:41.535695938Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<All keys matched successfully>"
|
||
|
},
|
||
|
"execution_count": 19,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# initialize the classifier with the encoder weights\n",
|
||
|
"classifier.encoder.load_state_dict(ae.encoder.state_dict())\n",
|
||
|
"cnn_classifier.encoder.load_state_dict(cnn_ae.encoder.state_dict())"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "aafa4a9ba7208647",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Transfer learning"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"id": "a60dd68f988a8249",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:52:09.051066459Z",
|
||
|
"start_time": "2023-10-03T12:49:41.543118970Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
" 80%|████████ | 4/5 [02:06<00:31, 31.63s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"mlp====> Epoch: 5 Average loss: 0.0539\n",
|
||
|
"mlp====> Epoch: 5 Training accuracy: 39.10%\n",
|
||
|
"cnn====> Epoch: 5 Average loss: 0.0333\n",
|
||
|
"cnn====> Epoch: 5 Training accuracy: 61.80%\n",
|
||
|
"====> Test set loss: 0.0536\n",
|
||
|
"====> Test set accuracy: 0.3961\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"100%|██████████| 5/5 [02:27<00:00, 29.50s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"====> Test set loss: 0.0341\n",
|
||
|
"====> Test set accuracy: 0.6133\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# fine-tune the classifier\n",
|
||
|
"learning_rate = 1e-5\n",
|
||
|
"epoch = 20\n",
|
||
|
"optimizer_pretrained = optim.Adam(classifier.parameters(), lr=learning_rate)\n",
|
||
|
"optimizer_pretrained_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n",
|
||
|
"for epoch in tqdm(range(1, epochs + 1)):\n",
|
||
|
" verbose = True if epoch % train_frequency == 0 else False\n",
|
||
|
" train_classifier(classifier, train_loader, optimizer_pretrained, criterion, epoch, verbose)\n",
|
||
|
" train_classifier(cnn_classifier, train_loader, optimizer_cnn, criterion, epoch, verbose)\n",
|
||
|
"\n",
|
||
|
" # test every n epochs\n",
|
||
|
" if epoch % test_frequency == 0:\n",
|
||
|
" test_acc_pretrained = test_classifier(classifier, test_loader, criterion)\n",
|
||
|
" test_acc_pretrained_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "31577275b833707a",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": [
|
||
|
"Compare the results of the linear probing with the results of the linear classifier"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"id": "40d0e7f3f13404c9",
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:52:09.054596162Z",
|
||
|
"start_time": "2023-10-03T12:52:09.050719737Z"
|
||
|
}
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Model Linear Accuracy Non-linear accuracy Pretrained accuracy \n",
|
||
|
"MLP AE 0.3933 0.5129 0.3961 \n",
|
||
|
"CNN AE 0.3981 0.6478 0.6133 \n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# print a table of the accuracies. compare the results with the results of the linear probing\n",
|
||
|
"print(f\"{'Model':<10} {'Linear Accuracy':<20} {'Non-linear accuracy':<20} {'Pretrained accuracy':<20}\")\n",
|
||
|
"print(f\"{'MLP AE':<10} {test_mlp[-1][1]:<20.4f} {test_acc:<20.4f} {test_acc_pretrained:<20.4f}\")\n",
|
||
|
"print(f\"{'CNN AE':<10} {test_cnn[-1][1]:<20.4f} {test_acc_cnn:<20.4f} {test_acc_pretrained_cnn:<20.4f}\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 22,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import torchvision.models as models\n",
|
||
|
"\n",
|
||
|
"class ResNet18Autoencoder(nn.Module):\n",
|
||
|
" def __init__(self):\n",
|
||
|
" super(ResNet18Autoencoder, self).__init__()\n",
|
||
|
" self.type = 'cnn'\n",
|
||
|
" # Encoder: Use pre-trained ResNet18 (without its final fc layer)\n",
|
||
|
" self.resnet18 = models.resnet18(pretrained=False)\n",
|
||
|
" self.encoder = nn.Sequential(*list(self.resnet18.children())[:-1], nn.Flatten())\n",
|
||
|
" \n",
|
||
|
" # Decoder: Create an up-sampling network\n",
|
||
|
" self.decoder = nn.Sequential(\n",
|
||
|
" nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), # 8x8\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # 16x16\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # 32x32\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), # 64x64\n",
|
||
|
" nn.ReLU(),\n",
|
||
|
" nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1), # 128x128\n",
|
||
|
" nn.Sigmoid() # to ensure pixel values are between 0 and 1\n",
|
||
|
" )\n",
|
||
|
" \n",
|
||
|
" def forward(self, x):\n",
|
||
|
" x = self.encoder(x)\n",
|
||
|
" # unflatten the output of the encoder to be fed into the decoder\n",
|
||
|
" x = x.view(x.size(0), 512, 1, 1)\n",
|
||
|
" x = self.decoder(x)\n",
|
||
|
" return x\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:52:09.068053637Z",
|
||
|
"start_time": "2023-10-03T12:52:09.054382586Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "e4037285d9d70694"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 23,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"ResNet18 AE parameters: 14476811\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
|
||
|
" warnings.warn(\n",
|
||
|
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
|
||
|
" warnings.warn(msg)\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Define the training parameters for the ResNet18 Autoencoder\n",
|
||
|
"batch_size = 128\n",
|
||
|
"epochs = 10\n",
|
||
|
"learning_rate = 1e-3\n",
|
||
|
"train_frequency = epochs\n",
|
||
|
"test_frequency = epochs\n",
|
||
|
"\n",
|
||
|
"# Create the ResNet18 Autoencoder\n",
|
||
|
"resnet18_ae = ResNet18Autoencoder().to(device)\n",
|
||
|
"# print the model's number of parameters\n",
|
||
|
"print(f\"ResNet18 AE parameters: {sum(p.numel() for p in resnet18_ae.parameters())}\")\n",
|
||
|
"\n",
|
||
|
"# Define the loss function and optimizer\n",
|
||
|
"criterion = nn.MSELoss()\n",
|
||
|
"optimizer = optim.Adam(resnet18_ae.parameters(), lr=learning_rate)\n",
|
||
|
"scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, verbose=True)\n",
|
||
|
"# Create the train and test dataloaders\n",
|
||
|
"train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
|
||
|
"test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:52:09.214178804Z",
|
||
|
"start_time": "2023-10-03T12:52:09.061558426Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "8963904ee52b9818"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Epoch 10 - Train Loss: 0.0013 - Test Loss: 0.0013: 90%|█████████ | 9/10 [01:49<00:11, 11.05s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"cnn====> Epoch: 10 Average loss: 0.0013\n",
|
||
|
"cnn Autoencoder\n",
|
||
|
"====> Test set loss: 0.0013\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
||
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
||
|
"\n",
|
||
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
||
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
||
|
"Please also refer to the documentation for alternative solver options:\n",
|
||
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
||
|
" n_iter_i = _check_optimize_result(\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Train accuracy: 0.36802\n",
|
||
|
"Test accuracy: 0.3665\n",
|
||
|
"KNN accuracy: 0.3889\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Epoch 10 - Train Loss: 0.0013 - Test Loss: 0.0013: 100%|██████████| 10/10 [02:15<00:00, 13.52s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Clustering ARI score: 0.02700985553219709\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Train the model\n",
|
||
|
"test_resnet18 = []\n",
|
||
|
"loss = 0\n",
|
||
|
"pbar = tqdm(range(1, epochs + 1))\n",
|
||
|
"for epoch in pbar:\n",
|
||
|
" verbose = True if epoch % train_frequency == 0 else False\n",
|
||
|
" train_loss = train(resnet18_ae, train_loader, optimizer, criterion, epoch, verbose)\n",
|
||
|
" \n",
|
||
|
" # Update tqdm description with the training loss\n",
|
||
|
" pbar.set_description(f\"Epoch {epoch} - Train Loss: {train_loss:.4f} - Test Loss: {loss:.4f}\")\n",
|
||
|
" \n",
|
||
|
" if epoch % 2 == 0:\n",
|
||
|
" loss = test_loss(resnet18_ae, test_loader, criterion, verbose=False)\n",
|
||
|
" scheduler.step(loss)\n",
|
||
|
"\n",
|
||
|
" # test every n epochs\n",
|
||
|
" if epoch % test_frequency == 0:\n",
|
||
|
" results_resnet_dic = test(resnet18_ae, train_loader, test_loader, criterion)\n",
|
||
|
" test_resnet18.append([results_resnet_dic['reconstruction_loss'], results_resnet_dic['linear_classification_accuracy'], results_resnet_dic['knn_classification_accuracy'], results_resnet_dic['clustering_ari_score']])\n",
|
||
|
" \n",
|
||
|
"\n",
|
||
|
" "
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:54:24.395386156Z",
|
||
|
"start_time": "2023-10-03T12:52:09.200297772Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "7af2592ad350ea8c"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Compare the evaluation results of the ResNet18 Autoencoder with the MLP and CNN Autoencoders "
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "18cc2efab5a2e4b0"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Model Reconstruction Loss Linear Accuracy KNN Accuracy Clustering ARI \n",
|
||
|
"MLP AE 0.0172 0.3933 0.3391 0.0397 \n",
|
||
|
"CNN AE 0.0044 0.3981 0.3545 0.0516 \n",
|
||
|
"ResNet18 AE 0.0013 0.3665 0.3889 0.0270 \n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"print(f\"{'Model':<15} {'Reconstruction Loss':<20} {'Linear Accuracy':<20} {'KNN Accuracy':<20} {'Clustering ARI':<20}\")\n",
|
||
|
"print(f\"{'MLP AE':<15} {test_mlp[-1][0]:<20.4f} {test_mlp[-1][1]:<20.4f} {test_mlp[-1][2]:<20.4f} {test_mlp[-1][3]:<20.4f}\")\n",
|
||
|
"print(f\"{'CNN AE':<15} {test_cnn[-1][0]:<20.4f} {test_cnn[-1][1]:<20.4f} {test_cnn[-1][2]:<20.4f} {test_cnn[-1][3]:<20.4f}\")\n",
|
||
|
"print(f\"{'ResNet18 AE':<15} {test_resnet18[-1][0]:<20.4f} {test_resnet18[-1][1]:<20.4f} {test_resnet18[-1][2]:<20.4f} {test_resnet18[-1][3]:<20.4f}\")"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:54:24.401416457Z",
|
||
|
"start_time": "2023-10-03T12:54:24.394398143Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "a5dcaa09ceaf1f3f"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class ResNet18Classifier(nn.Module):\n",
|
||
|
" def __init__(self, num_classes=10, pretrained=False):\n",
|
||
|
" super(ResNet18Classifier, self).__init__()\n",
|
||
|
" self.type = 'cnn'\n",
|
||
|
" # Load the ResNet18 model\n",
|
||
|
" self.resnet18 = models.resnet18(pretrained=pretrained)\n",
|
||
|
" \n",
|
||
|
" # Adjust the first convolutional layer for CIFAR-10 image size\n",
|
||
|
" self.resnet18.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
|
||
|
" \n",
|
||
|
" # Adjust the final fully connected layer for CIFAR-10 number of classes\n",
|
||
|
" num_ftrs = self.resnet18.fc.in_features\n",
|
||
|
" self.resnet18.fc = nn.Linear(num_ftrs, num_classes)\n",
|
||
|
" \n",
|
||
|
" # Freeze the encoder weights except the final fc layer\n",
|
||
|
" for param in self.resnet18.parameters():\n",
|
||
|
" param.requires_grad = False\n",
|
||
|
" for param in self.resnet18.fc.parameters():\n",
|
||
|
" param.requires_grad = True\n",
|
||
|
" \n",
|
||
|
" def forward(self, x):\n",
|
||
|
" return self.resnet18(x)"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:54:24.409917792Z",
|
||
|
"start_time": "2023-10-03T12:54:24.398893967Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "f4e2e62d832761c0"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
|
||
|
" warnings.warn(\n",
|
||
|
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
|
||
|
" warnings.warn(msg)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"ResNet18 classifier parameters: 11173962\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
|
||
|
" warnings.warn(msg)\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Define the training parameters for the ResNet18 classifier\n",
|
||
|
"batch_size = 128\n",
|
||
|
"epochs = 50\n",
|
||
|
"learning_rate = 1e-4\n",
|
||
|
"train_frequency = epochs\n",
|
||
|
"test_frequency = epochs\n",
|
||
|
"\n",
|
||
|
"# Create the ResNet18 classifier\n",
|
||
|
"resnet18_classifier = ResNet18Classifier(num_classes=10, pretrained=False).to(device)\n",
|
||
|
"resnet18_classifier_pretrained = ResNet18Classifier(num_classes=10, pretrained=True).to(device)\n",
|
||
|
"# print the model's number of parameters\n",
|
||
|
"print(f\"ResNet18 classifier parameters: {sum(p.numel() for p in resnet18_classifier.parameters())}\")\n",
|
||
|
"\n",
|
||
|
"# Define the loss function and optimizer\n",
|
||
|
"criterion = nn.CrossEntropyLoss()\n",
|
||
|
"optimizer = optim.Adam(resnet18_classifier.parameters(), lr=learning_rate)\n",
|
||
|
"optimizer_pretrained = optim.Adam(resnet18_classifier_pretrained.parameters(), lr=learning_rate)\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T12:54:24.817741667Z",
|
||
|
"start_time": "2023-10-03T12:54:24.402570604Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "1f4d0fbf693153da"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 28,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Epoch 49 - Train Loss: 0.0138 - Train Loss Pretrained: 0.0131: 98%|█████████▊| 49/50 [11:20<00:13, 13.33s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"cnn====> Epoch: 50 Average loss: 0.0138\n",
|
||
|
"cnn====> Epoch: 50 Training accuracy: 37.49%\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Epoch 50 - Train Loss: 0.0138 - Train Loss Pretrained: 0.0131: 98%|█████████▊| 49/50 [11:33<00:13, 13.33s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"cnn====> Epoch: 50 Average loss: 0.0131\n",
|
||
|
"cnn====> Epoch: 50 Training accuracy: 41.32%\n",
|
||
|
"====> Test set loss: 0.0142\n",
|
||
|
"====> Test set accuracy: 0.3560\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Epoch 50 - Train Loss: 0.0138 - Train Loss Pretrained: 0.0131: 100%|██████████| 50/50 [11:36<00:00, 13.93s/it]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"====> Test set loss: 0.0136\n",
|
||
|
"====> Test set accuracy: 0.3913\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Train the model\n",
|
||
|
"test_resnet18_classifier = []\n",
|
||
|
"test_resnet18_classifier_pretrained = []\n",
|
||
|
"pbar = tqdm(range(1, epochs + 1))\n",
|
||
|
"\n",
|
||
|
"for epoch in pbar:\n",
|
||
|
" verbose = True if epoch % train_frequency == 0 else False\n",
|
||
|
" train_loss = train_classifier(resnet18_classifier, train_loader, optimizer, criterion, epoch, verbose)\n",
|
||
|
" train_loss_pretrained = train_classifier(resnet18_classifier_pretrained, train_loader, optimizer_pretrained, criterion, epoch, verbose)\n",
|
||
|
" \n",
|
||
|
" # Update tqdm description with the training loss\n",
|
||
|
" pbar.set_description(f\"Epoch {epoch} - Train Loss: {train_loss:.4f} - Train Loss Pretrained: {train_loss_pretrained:.4f}\")\n",
|
||
|
" \n",
|
||
|
" # test every n epochs\n",
|
||
|
" if epoch % test_frequency == 0:\n",
|
||
|
" results_resnet_acc = test_classifier(resnet18_classifier, test_loader, criterion)\n",
|
||
|
" test_resnet18_classifier.append(results_resnet_acc)\n",
|
||
|
" results_resnet_acc_pretrained = test_classifier(resnet18_classifier_pretrained, test_loader, criterion)\n",
|
||
|
" test_resnet18_classifier_pretrained.append(results_resnet_acc_pretrained)\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T13:06:01.403787907Z",
|
||
|
"start_time": "2023-10-03T12:54:24.817533512Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "8452e383237a283"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"Compare the accuracy results of the ResNet18 classifier with the MLP and CNN classifiers"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "9b3b561b83de5ccc"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 29,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Model Accuracy \n",
|
||
|
"MLP AE 0.3933 \n",
|
||
|
"CNN AE 0.3981 \n",
|
||
|
"ResNet18 AE 0.3665 \n",
|
||
|
"ResNet18 Classifier 0.3560 \n",
|
||
|
"ResNet18 Classifier Pretrained 0.3913 \n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import numpy as np\n",
|
||
|
"\n",
|
||
|
"print(f\"{'Model':<15} {'Accuracy':<20}\")\n",
|
||
|
"print(f\"{'MLP AE':<15} {test_mlp[-1][1]:<20.4f}\")\n",
|
||
|
"print(f\"{'CNN AE':<15} {test_cnn[-1][1]:<20.4f}\")\n",
|
||
|
"print(f\"{'ResNet18 AE':<15} {test_resnet18[-1][1]:<20.4f}\")\n",
|
||
|
"# take the average of the test accuracies\n",
|
||
|
"print(f\"{'ResNet18 Classifier':<15} {np.mean(test_resnet18_classifier):<20.4f}\")\n",
|
||
|
"print(f\"{'ResNet18 Classifier Pretrained':<15} {np.mean(test_resnet18_classifier_pretrained):<20.4f}\")\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"collapsed": false,
|
||
|
"ExecuteTime": {
|
||
|
"end_time": "2023-10-03T13:06:01.408109408Z",
|
||
|
"start_time": "2023-10-03T13:06:01.403113475Z"
|
||
|
}
|
||
|
},
|
||
|
"id": "8d1f038afa9bcff2"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "9ff6e7674c4a3e71"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"outputs": [],
|
||
|
"source": [],
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"id": "36fdbf815e16107b"
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 2
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython2",
|
||
|
"version": "2.7.6"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|