AppliedMachineAndDeepLearni.../autoencoder_cifar10.ipynb

1623 lines
121 KiB
Plaintext
Raw Permalink Normal View History

2023-10-09 10:02:01 +00:00
{
"cells": [
{
"cell_type": "markdown",
"source": [
"| Credentials | |\n",
"|----|----------------------------------|\n",
"|Host | Montanuniversitaet Leoben |\n",
"|Web | https://cps.unileoben.ac.at |\n",
"|Mail | cps@unileoben.ac.at |\n",
"|Author | Fotios Lygerakis |\n",
"|Corresponding Authors | fotios.lygerakis@unileoben.ac.at |\n",
"|Last edited | 28.09.2023 |"
],
"metadata": {
"collapsed": false
},
"id": "25b3e26209aea07e"
},
{
"cell_type": "markdown",
"id": "1490260facaff836",
"metadata": {
"collapsed": false
},
"source": [
"In the first part of this tutorial we will build a fully connected MLP Autoencoder on the CIFAR10 dataset. Then we will perform linear probing on the encoder features to see how well they perform on a linear classification task."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "fc18830bb6f8d534",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:41:16.496631884Z",
"start_time": "2023-10-03T12:41:14.795432603Z"
}
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import matplotlib.pyplot as plt\n",
"from torchvision import datasets, transforms\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.metrics import adjusted_rand_score\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.cluster import KMeans\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "markdown",
"id": "1bad4bd03deb5b7e",
"metadata": {
"collapsed": false
},
"source": [
"Set random seed"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "27dd48e60ae7dd9e",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:41:16.504621133Z",
"start_time": "2023-10-03T12:41:16.497289235Z"
}
},
"outputs": [
{
"data": {
"text/plain": "<torch._C.Generator at 0x7f098e1490d0>"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Set random seed\n",
"torch.manual_seed(0)"
]
},
{
"cell_type": "markdown",
"id": "cc7f167a33227e94",
"metadata": {
"collapsed": false
},
"source": [
"Load the CIFAR10 dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "34248e8bc2678fd3",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:41:18.687951160Z",
"start_time": "2023-10-03T12:41:17.491688103Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
],
"source": [
"# Define the transformations\n",
"transform = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,))])\n",
"# Download and load the training data\n",
"trainset = datasets.CIFAR10('data', download=True, train=True, transform=transform)\n",
"# Download and load the test data\n",
"testset = datasets.CIFAR10('data', download=True, train=False, transform=transform)\n"
]
},
{
"cell_type": "markdown",
"id": "928dfac955d0d778",
"metadata": {
"collapsed": false
},
"source": [
"Print some examples from the dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "87c6eae807f51118",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:41:19.826292939Z",
"start_time": "2023-10-03T12:41:19.492130184Z"
}
},
"outputs": [
{
"data": {
"text/plain": "<Figure size 1000x400 with 10 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA8cAAAGJCAYAAACnwkFvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAADLAUlEQVR4nOz9eZhdVZ3vj3/2medzaq5KZahKZZ4IJkxhSEAwMimoDNo2RBBoWy8OF1q5rQI/7XZEUbQFtUUUtZVBLo3K1IIiIDOBEDJXhqqk5jp15mGfvb5/+Etdi/cHOIypynm/nqeflnf2sPaa9tpVWa9YxhgjhBBCCCGEEEJIDeM60AUghBBCCCGEEEIONPw4JoQQQgghhBBS8/DjmBBCCCGEEEJIzcOPY0IIIYQQQgghNQ8/jgkhhBBCCCGE1Dz8OCaEEEIIIYQQUvPw45gQQgghhBBCSM3Dj2NCCCGEEEIIITUPP44JIYQQQgghhNQ8/Dj+/7Nz506xLEu++c1vvmnXfPDBB8WyLHnwwQdf9zV+/vOfy4IFC8Tr9UoikXjTykbI3zNZ+/+BZs2aNbJkyZIDXQxyAOCY0OGYIFNpbKxZs0bWrFnzpl6TTH6mUh+tBV5Le1x11VViWdaErKOjQ9atW/cWlQ6Z0h/HP/3pT8WyLHnyyScPdFHeEjZt2iTr1q2Trq4u+dGPfiQ//OEPD3SRyCTiYO//v/zlL+Xaa6890MUgUwiOCUJ0DvaxQaY+7KOvjVwuJ1dddRU/1t8CpvTH8cHOgw8+KI7jyHe+8x1Zt26dnH322Qe6SIS8bfBDgJCJcEwQQggR+dvH8dVXX33QfRx//vOfl3w+f0DLwI/jSczAwICIyKv+dWpjzAHvSIQcSAqFgjiOc6CLQcikgWOCkANHLpc70EUgZEri8XgkEAgc0DIc9B/HpVJJvvjFL8qKFSskHo9LOByWY489Vh544IGXPefb3/62zJo1S4LBoKxevVo2bNgAx2zatEk+8IEPSH19vQQCAVm5cqXceeedr1qeXC4nmzZtkqGhoVc8rqOjQ6688koREWlqahLLsuSqq64a/7PTTjtN7rnnHlm5cqUEg0G54YYbRERkx44dctZZZ0l9fb2EQiE58sgj5Xe/+x1cf9euXfKe97xHwuGwNDc3y6c//Wm55557uJ/iIGOq9v81a9bI7373O9m1a5dYliWWZUlHR4eI/L99P//1X/8ln//856W9vV1CoZCkUil1r4rI//vrWjt37pyQ/+EPf5DVq1dLNBqVWCwmhx12mPzyl798xbLde++9EgqF5IMf/KDYtv2qz0wmFxwTf4NjgryUqTo29vPDH/5Qurq6JBgMyuGHHy4PPfSQelyxWJQrr7xS5syZI36/X2bMmCH/8i//IsViEY69+eabZcWKFRIMBqW+vl7OPfdc2bNnz4Rj9u/Df+qpp+S4446TUCgk/+f//J+qykxeG1O5jz700ENy1llnycyZM8f73ac//Wn45dbL7ZNft27d+Jy/c+dOaWpqEhGRq6++evydsP87QUTkj3/8oxx77LESDoclkUjIe9/7XnnxxRcnXHP/+2HLli3y4Q9/WOLxuDQ1NckXvvAFMcbInj175L3vfa/EYjFpbW2Va665Bso1MDAgF154obS0tEggEJBDDjlEbrrpppeth1drj5d7Z72UZDIpn/rUp2TGjBni9/tlzpw58rWvfe1N+aGw5w1fYZKTSqXkxz/+sXzwgx+Uiy66SNLptPznf/6nrF27Vh5//HFZvnz5hON/9rOfSTqdlo9//ONSKBTkO9/5jpxwwgny/PPPS0tLi4iIvPDCC3L00UdLe3u7fO5zn5NwOCy/+c1v5IwzzpDbbrtNzjzzzJctz+OPPy7HH3+8XHnllRM68Uu59tpr5Wc/+5n89re/lR/84AcSiURk2bJl43++efNm+eAHPyiXXHKJXHTRRTJ//nzp7++XVatWSS6Xk0svvVQaGhrkpptukve85z1y6623jpcrm83KCSecIPv27ZNPfvKT0traKr/85S9fcXIhU5Op2v//9V//VcbGxqSnp0e+/e1vi4hIJBKZcMyXvvQl8fl8ctlll0mxWBSfz/ea6uanP/2pXHDBBbJ48WK54oorJJFIyDPPPCN33323fOhDH1LPueuuu+QDH/iAnHPOOfKTn/xE3G73a7onOfBwTLw8HBO1zVQdGyIi//mf/ymXXHKJrFq1Sj71qU/Jjh075D3veY/U19fLjBkzxo9zHEfe8573yF/+8he5+OKLZeHChfL888/Lt7/9bdmyZYvccccd48f+27/9m3zhC1+Qs88+Wz760Y/K4OCgXHfddXLcccfJM888M+Fv9Q0PD8vJJ58s5557rnz4wx8ef37y5jKV++gtt9wiuVxOPvaxj0lDQ4M8/vjjct1110lPT4/ccsstr6kempqa5Ac/+IF87GMfkzPPPFPe9773iYiMfyfcf//9cvLJJ8vs2bPlqquuknw+L9ddd50cffTR8vTTT49/ZO/nnHPOkYULF8pXv/pV+d3vfidf/vKXpb6+Xm644QY54YQT5Gtf+5r84he/kMsuu0wOO+wwOe6440REJJ/Py5o1a2Tbtm3yiU98Qjo7O+WWW26RdevWSTKZlE9+8pMT7lNNe1RDLpeT1atXS29vr1xyySUyc+ZMeeSRR+SKK66Qffv2vfHtR2YKc+ONNxoRMU888cTLHmPbtikWixOy0dFR09LSYi644ILxrLu724iICQaDpqenZzx/7LHHjIiYT3/60+PZO9/5TrN06VJTKBTGM8dxzKpVq8zcuXPHswceeMCIiHnggQcgu/LKK1/1+a688kojImZwcHBCPmvWLCMi5u67756Qf+pTnzIiYh566KHxLJ1Om87OTtPR0WEqlYoxxphrrrnGiIi54447xo/L5/NmwYIFUF4yeTnY+/+pp55qZs2aBfn+a8yePdvkcrkJf7Z/zLyU/XXV3d1tjDEmmUyaaDRqjjjiCJPP5ycc6zjO+P9evXq1Wbx4sTHGmNtuu814vV5z0UUXjY8lMrngmOCYIDoH89golUqmubnZLF++fEL5f/jDHxoRMatXrx7Pfv7znxuXyzVhnWSMMddff70REfPwww8bY4zZuXOncbvd5t/+7d8mHPf8888bj8czIV+9erUREXP99de/YjnJK3Mw91FjDMzNxhjzla98xViWZXbt2jWerV69ekKf3c/5558/Yf4fHBx82XsvX77cNDc3m+Hh4fFs/fr1xuVymfPOO2882/9+uPjii8cz27bN9OnTjWVZ5qtf/ep4Pjo6aoLBoDn//PPHs2uvvdaIiLn55pvHs1KpZI466igTiURMKpUyxry29tDeWbNmzZpw3y996UsmHA6bLVu2TDjuc5/7nHG73Wb37t1QJ6+Fg/6vVbvd7vGfnjuOIyMjI2LbtqxcuVKefvppOP6MM86Q9vb28f8+/PDD5YgjjpDf//73IiIyMjIif/zjH+Xss8+WdDotQ0NDMjQ0JMPDw7J27VrZunWr9Pb2vmx51qxZI8aYV/0J06vR2dkpa9eunZD9/ve/l8MPP1yOOeaY8SwSicjFF18sO3fulI0bN4qIyN133y3t7e3ynve8Z/y4QCAgF1100RsqE5l8HKz9X0Tk/PPPl2Aw+LrOve+++ySdTsvnPvc52Nui/XWeX/3qV3LOOefIJZdcIjfccIO4XAf91HnQwjGhwzFBpurYePLJJ2VgYED+6Z/+acLflli3bp3E4/EJx95yyy2ycOFCWbBgwXh5hoaG5IQTThARGf8bdLfffrs4jiNnn332hONaW1tl7ty58Dft/H6/fOQjH3nFcpI3zlTtoyIyYW7OZrMyNDQkq1atEmOMPPPMM9VWwauyb98+efbZZ2XdunVSX18/ni9btkxOOumk8Wf/ez760Y+O/2+32y0rV64UY4xceOGF43kikZD58+fLjh07xrPf//730traKh/84AfHM6/XK5deeqlkMhn505/+NOE+r9Ye1XLLLbfIscceK3V1dRPG54knniiVSkX+/Oc/v6brvZSD/q9Vi4jcdNNNcs0118i
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Get the first 10 samples from CIFAR10\n",
"dataiter = iter(trainset)\n",
"images, labels = [], []\n",
"for i in range(10):\n",
" image, label = next(dataiter)\n",
" images.append(image)\n",
" labels.append(label)\n",
"\n",
"# CIFAR10 label names\n",
"cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n",
"\n",
"# Plot the CIFAR10 samples\n",
"fig, axes = plt.subplots(2, 5, figsize=(10, 4))\n",
"for ax, img, lbl in zip(axes.ravel(), images, labels):\n",
" ax.imshow(img.permute(1, 2, 0).numpy() * 0.5 + 0.5) # denormalize\n",
" ax.set_title(f'Label: {cifar10_classes[lbl]}')\n",
" ax.axis('off')\n",
"plt.tight_layout()\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"id": "e4e25962ef8e5b0d",
"metadata": {
"collapsed": false
},
"source": [
"Define the MLP and Convolutional Autoencoder"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "26f2513d92b78e1e",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:41:21.511063666Z",
"start_time": "2023-10-03T12:41:21.485979560Z"
}
},
"outputs": [],
"source": [
"\n",
"class Autoencoder(nn.Module):\n",
" def __init__(self, input_size, hidden_size, type='mlp'):\n",
" super(Autoencoder, self).__init__()\n",
" self.type = type\n",
" if self.type == 'mlp':\n",
" self.encoder = nn.Sequential(\n",
" nn.Linear(input_size, hidden_size),\n",
" nn.ReLU(True))\n",
" self.decoder = nn.Sequential(\n",
" nn.Linear(hidden_size, input_size),\n",
" nn.ReLU(True),\n",
" nn.Sigmoid()\n",
" )\n",
" elif self.type == 'cnn':\n",
" # Encoder module for CIFAR10\n",
" self.encoder = nn.Sequential(\n",
" nn.Conv2d(input_size, 16, 3, stride=2, padding=1), # 16x16x16\n",
" nn.ReLU(True),\n",
" nn.Conv2d(16, 32, 3, stride=2, padding=1), # 8x8x32\n",
" nn.ReLU(True),\n",
" nn.Conv2d(32, 64, 3, stride=2, padding=1), # 4x4x64\n",
" nn.ReLU(True)\n",
" )\n",
" # Decoder module for CIFAR10\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), # 8x8x32\n",
" nn.ReLU(True),\n",
" nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1), # 16x16x16\n",
" nn.ReLU(True),\n",
" nn.ConvTranspose2d(16, input_size, 3, stride=2, padding=1, output_padding=1), # 32x32x3\n",
" nn.Sigmoid()\n",
" )\n",
" else:\n",
" raise ValueError(f\"Unknown Autoencoder type: {type}\")\n",
" \n",
" def forward(self, x):\n",
" x = self.encoder(x)\n",
" x = self.decoder(x)\n",
" return x\n"
]
},
{
"cell_type": "markdown",
"id": "91a01313b4d95274",
"metadata": {
"collapsed": false
},
"source": [
"Check if GPU support is available"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "67006b35b75d8dff",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:41:22.884901160Z",
"start_time": "2023-10-03T12:41:22.860000300Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"source": [
"# device\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
]
},
{
"cell_type": "markdown",
"id": "8eebf70cb27640d5",
"metadata": {
"collapsed": false
},
"source": [
"Define the training function"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5f96f7be13984747",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:41:23.954782535Z",
"start_time": "2023-10-03T12:41:23.943574306Z"
}
},
"outputs": [],
"source": [
"# Define the training function\n",
"def train(model, train_loader, optimizer, criterion, epoch, verbose=True):\n",
" model.train()\n",
" train_loss = 0\n",
" for i, (data, _) in enumerate(train_loader):\n",
" # check the type of autoencoder and modify the input data accordingly\n",
" if model.type == 'mlp':\n",
" data = data.view(data.size(0), -1)\n",
" data = data.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = criterion(output, data)\n",
" loss.backward()\n",
" train_loss += loss.item()\n",
" optimizer.step()\n",
" train_loss /= len(train_loader.dataset)\n",
" if verbose:\n",
" print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}') \n",
" return train_loss"
]
},
{
"cell_type": "markdown",
"id": "5f6386edcab6b1e4",
"metadata": {
"collapsed": false
},
"source": [
"The evaluation functions for the linear classification"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "b2c4483492fdd427",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:41:24.972220830Z",
"start_time": "2023-10-03T12:41:24.960696742Z"
}
},
"outputs": [],
"source": [
"# Extract encoded representations for a given loader\n",
"def extract_features(loader, model):\n",
" features = []\n",
" labels = []\n",
" model.eval()\n",
" with torch.no_grad():\n",
" for data in loader:\n",
" img, label = data\n",
" if model.type == 'mlp':\n",
" img = img.view(img.size(0), -1)\n",
" img = img.to(device)\n",
" feature = model.encoder(img)\n",
" if model.type == 'cnn':\n",
" feature = feature.view(feature.size(0), -1) # Flatten the CNN encoded features\n",
" features.append(feature)\n",
" labels.append(label)\n",
" return torch.cat(features), torch.cat(labels)\n",
"\n",
"# Define the loss test function\n",
"def test_loss(model, test_loader, criterion, verbose=True):\n",
" model.eval()\n",
" eval_loss = 0\n",
" with torch.no_grad():\n",
" for i, (data, _) in enumerate(test_loader):\n",
" # check the type of autoencoder and modify the input data accordingly\n",
" if model.type == 'mlp':\n",
" data = data.view(data.size(0), -1)\n",
" data = data.to(device)\n",
" output = model(data)\n",
" eval_loss += criterion(output, data).item()\n",
" eval_loss /= len(test_loader.dataset)\n",
" if verbose:\n",
" print('====> Test set loss: {:.4f}'.format(eval_loss))\n",
" return eval_loss\n",
"\n",
"# Define the linear classification test function\n",
"def test_linear(encoded_train, train_labels, encoded_test, test_labels):\n",
" train_features_np = encoded_train.cpu().numpy()\n",
" train_labels_np = train_labels.cpu().numpy()\n",
" test_features_np = encoded_test.cpu().numpy()\n",
" test_labels_np = test_labels.cpu().numpy()\n",
" \n",
" # Apply logistic regression on train features and labels\n",
" logistic_regression = LogisticRegression(random_state=0, max_iter=100).fit(train_features_np, train_labels_np)\n",
" print(f\"Train accuracy: {logistic_regression.score(train_features_np, train_labels_np)}\")\n",
" # Apply logistic regression on test features and labels\n",
" test_accuracy = logistic_regression.score(test_features_np, test_labels_np)\n",
" print(f\"Test accuracy: {test_accuracy}\")\n",
" return test_accuracy\n",
"\n",
"\n",
"def test_clustering(encoded_features, true_labels):\n",
" encoded_features_np = encoded_features.cpu().numpy()\n",
" true_labels_np = true_labels.cpu().numpy()\n",
" \n",
" # Apply k-means clustering\n",
" kmeans = KMeans(n_clusters=10, n_init=10, random_state=0).fit(encoded_features_np)\n",
" cluster_labels = kmeans.labels_\n",
" \n",
" # Evaluate clustering results using Adjusted Rand Index\n",
" ari_score = adjusted_rand_score(true_labels_np, cluster_labels)\n",
" print(f\"Clustering ARI score: {ari_score}\")\n",
" return ari_score\n",
"\n",
"def knn_classifier(encoded_train, train_labels, encoded_test, test_labels, k=5):\n",
" encoded_train_np = encoded_train.cpu().numpy()\n",
" encoded_test_np = encoded_test.cpu().numpy()\n",
" train_labels_np = train_labels.cpu().numpy()\n",
" test_labels_np = test_labels.cpu().numpy()\n",
" \n",
" # Apply k-nearest neighbors classification\n",
" knn = KNeighborsClassifier(n_neighbors=k).fit(encoded_train_np, train_labels_np)\n",
" accuracy_score = knn.score(encoded_test_np, test_labels_np)\n",
" print(f\"KNN accuracy: {accuracy_score}\")\n",
" return accuracy_score\n",
"\n",
"def test(model, train_loader, test_loader, criterion):\n",
" # Extract features once for all tests\n",
" encoded_train, train_labels = extract_features(train_loader, model)\n",
" encoded_test, test_labels = extract_features(test_loader, model)\n",
" print(f\"{model.type} Autoencoder\")\n",
" results = {\n",
" 'reconstruction_loss': test_loss(model, test_loader, criterion),\n",
" 'linear_classification_accuracy': test_linear(encoded_train, train_labels, encoded_test, test_labels),\n",
" 'knn_classification_accuracy': knn_classifier(encoded_train, train_labels, encoded_test, test_labels),\n",
" 'clustering_ari_score': test_clustering(encoded_test, test_labels)\n",
" }\n",
" \n",
" # Save results to a log file\n",
" with open(\"evaluation_results.log\", \"w\") as log_file:\n",
" for key, value in results.items():\n",
" log_file.write(f\"{key}: {value}\")\n",
" \n",
" return results\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "bcb22bc5af9fb014",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:41:26.789225198Z",
"start_time": "2023-10-03T12:41:26.359259721Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MLP AE parameters: 789632\n",
"CNN AE parameters: 47107\n"
]
}
],
"source": [
"# Define the training parameters for the fully connected MLP Autoencoder\n",
"batch_size = 32\n",
"epochs = 5\n",
"input_size = trainset.data.shape[1] * trainset.data.shape[2] * trainset.data.shape[3]\n",
"hidden_size = 128\n",
"train_frequency = epochs\n",
"test_frequency = epochs\n",
"\n",
"# Create the fully connected MLP Autoencoder\n",
"ae = Autoencoder(input_size, hidden_size, type='mlp').to(device)\n",
"input_size = trainset.data.shape[3]\n",
"cnn_ae = Autoencoder(input_size, hidden_size, type='cnn').to(device)\n",
"# print the models' number of parameters\n",
"print(f\"MLP AE parameters: {sum(p.numel() for p in ae.parameters())}\")\n",
"print(f\"CNN AE parameters: {sum(p.numel() for p in cnn_ae.parameters())}\")\n",
"\n",
"# Define the loss function and optimizer\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(ae.parameters(), lr=1e-3)\n",
"optimizer_cnn = optim.Adam(cnn_ae.parameters(), lr=1e-3)\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f1626ce45bb25883",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:41:27.399274149Z",
"start_time": "2023-10-03T12:41:27.378668195Z"
}
},
"outputs": [],
"source": [
"# Create the train and test dataloaders\n",
"train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
"test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7472159fdc5f2532",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:02.996829577Z",
"start_time": "2023-10-03T12:41:27.969843962Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 80%|████████ | 4/5 [01:21<00:22, 22.85s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mlp====> Epoch: 5 Average loss: 0.0174\n",
"cnn====> Epoch: 5 Average loss: 0.0045\n",
"mlp Autoencoder\n",
"====> Test set loss: 0.0172\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train accuracy: 0.39372\n",
"Test accuracy: 0.3933\n",
"KNN accuracy: 0.3391\n",
"Clustering ARI score: 0.03965075119494781\n",
"cnn Autoencoder\n",
"====> Test set loss: 0.0044\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train accuracy: 0.41372\n",
"Test accuracy: 0.3981\n",
"KNN accuracy: 0.3545\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5/5 [05:34<00:00, 67.00s/it] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Clustering ARI score: 0.051552759482472406\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"test_mlp = []\n",
"test_cnn = []\n",
"# Train the model\n",
"for epoch in tqdm(range(1, epochs + 1)):\n",
" verbose = True if epoch % train_frequency == 0 else False\n",
" train(ae, train_loader, optimizer, criterion, epoch, verbose)\n",
" train(cnn_ae, train_loader, optimizer_cnn, criterion, epoch, verbose)\n",
"\n",
" # test every n epochs\n",
" if epoch % test_frequency == 0:\n",
" restults_dic = test(ae, train_loader, test_loader, criterion)\n",
" test_mlp.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])\n",
" restults_dic = test(cnn_ae, train_loader, test_loader, criterion)\n",
" test_cnn.append([restults_dic['reconstruction_loss'], restults_dic['linear_classification_accuracy'], restults_dic['knn_classification_accuracy'], restults_dic['clustering_ari_score']])"
]
},
{
"cell_type": "markdown",
"id": "10639256e342a159",
"metadata": {
"collapsed": false
},
"source": [
"Compare the evaluation results of the MLP and CNN Autoencoders"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "50bb4c3c58af09ee",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T13:06:01.418631529Z",
"start_time": "2023-10-03T13:06:01.405317147Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model Reconstruction Loss Linear Accuracy KNN Accuracy Clustering ARI \n",
"MLP AE 0.0172 0.3933 0.3391 0.0397 \n",
"CNN AE 0.0044 0.3981 0.3545 0.0516 \n"
]
}
],
"source": [
"print(f\"{'Model':<10} {'Reconstruction Loss':<20} {'Linear Accuracy':<20} {'KNN Accuracy':<20} {'Clustering ARI':<20}\")\n",
"print(f\"{'MLP AE':<10} {test_mlp[-1][0]:<20.4f} {test_mlp[-1][1]:<20.4f} {test_mlp[-1][2]:<20.4f} {test_mlp[-1][3]:<20.4f}\")\n",
"print(f\"{'CNN AE':<10} {test_cnn[-1][0]:<20.4f} {test_cnn[-1][1]:<20.4f} {test_cnn[-1][2]:<20.4f} {test_cnn[-1][3]:<20.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "b9201d1403781706",
"metadata": {
"collapsed": false
},
"source": [
"Develop a linear classifier with fully connected layers"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "1612800950703181",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:03.004530043Z",
"start_time": "2023-10-03T12:47:02.996038048Z"
}
},
"outputs": [],
"source": [
"# Define the fully connected classifier for MNIST\n",
"class DenseClassifier(nn.Module):\n",
" def __init__(self, input_size=784, hidden_size=500, num_classes=10):\n",
" super(DenseClassifier, self).__init__()\n",
" self.type = 'mlp'\n",
" self.encoder = nn.Sequential(\n",
" nn.Linear(input_size, hidden_size),\n",
" nn.ReLU(True))\n",
" self.fc1 = nn.Linear(hidden_size, num_classes)\n",
"\n",
" def forward(self, x):\n",
" x = x.view(x.size(0), -1) # Flatten the input tensor\n",
" x = self.encoder(x)\n",
" x = self.fc1(x)\n",
" return x\n"
]
},
{
"cell_type": "markdown",
"id": "35db4190e9c7f716",
"metadata": {
"collapsed": false
},
"source": [
"Develop a non-linear classifier with convolutional layers"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "cb2dfcf75113fd0b",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:03.041289873Z",
"start_time": "2023-10-03T12:47:02.999128924Z"
}
},
"outputs": [],
"source": [
"# cnn classifier\n",
"class CNNClassifier(nn.Module):\n",
" def __init__(self, input_size, hidden_size=128, num_classes=10):\n",
" super(CNNClassifier, self).__init__()\n",
" self.type = 'cnn'\n",
" # Encoder (Feature extractor)\n",
" self.encoder = nn.Sequential(\n",
" nn.Conv2d(input_size, 16, 3, stride=2, padding=1), # 16x16x16\n",
" nn.ReLU(True),\n",
" nn.Conv2d(16, 32, 3, stride=2, padding=1), # 8x8x32\n",
" nn.ReLU(True),\n",
" nn.Conv2d(32, 64, 3, stride=2, padding=1), # 4x4x64\n",
" nn.ReLU(True)\n",
" )\n",
" \n",
" # Classifier\n",
" self.classifier = nn.Sequential(\n",
" nn.Flatten(), \n",
" nn.Linear(4*4*64, hidden_size),\n",
" nn.ReLU(True),\n",
" nn.Linear(hidden_size, num_classes)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.encoder(x)\n",
" x = self.classifier(x)\n",
" return x\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "7c3cf2371479da0c",
"metadata": {
"collapsed": false
},
"source": [
"Train and test functions for the non-linear classifier"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "ac980d25bd8a3dd3",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:03.062646362Z",
"start_time": "2023-10-03T12:47:03.020366991Z"
}
},
"outputs": [],
"source": [
"# Train for the classifier\n",
"def train_classifier(model, train_loader, optimizer, criterion, epoch, verbose=True):\n",
" model.train()\n",
" train_loss = 0\n",
" correct = 0 \n",
" for i, (data, target) in enumerate(train_loader):\n",
" if model.type == 'cnn':\n",
" data = data.to(device)\n",
" else:\n",
" data = data.view(data.size(0), -1)\n",
" data = data.to(device)\n",
" target = target.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = criterion(output, target)\n",
" loss.backward()\n",
" train_loss += loss.item()\n",
" optimizer.step()\n",
" # Calculate correct predictions for training accuracy\n",
" pred = output.argmax(dim=1, keepdim=True)\n",
" correct += pred.eq(target.view_as(pred)).sum().item()\n",
"\n",
" train_loss /= len(train_loader.dataset)\n",
" train_accuracy = 100. * correct / len(train_loader.dataset)\n",
" if verbose:\n",
" print(f'{model.type}====> Epoch: {epoch} Average loss: {train_loss:.4f}')\n",
" print(f'{model.type}====> Epoch: {epoch} Training accuracy: {train_accuracy:.2f}%')\n",
" return train_loss\n",
"\n",
"\n",
"def test_classifier(model, test_loader, criterion):\n",
" model.eval()\n",
" eval_loss = 0\n",
" correct = 0\n",
" with torch.no_grad():\n",
" for i, (data, target) in enumerate(test_loader):\n",
" if model.type == 'cnn':\n",
" data = data.to(device)\n",
" else:\n",
" data = data.view(data.size(0), -1)\n",
" data = data.to(device)\n",
" target = target.to(device)\n",
" output = model(data)\n",
" eval_loss += criterion(output, target).item()\n",
" pred = output.argmax(dim=1, keepdim=True)\n",
" correct += pred.eq(target.view_as(pred)).sum().item()\n",
" eval_loss /= len(test_loader.dataset)\n",
" print('====> Test set loss: {:.4f}'.format(eval_loss))\n",
" accuracy = correct / len(test_loader.dataset)\n",
" print('====> Test set accuracy: {:.4f}'.format(accuracy))\n",
" return accuracy\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "dff05e622dcfd774",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:03.092600803Z",
"start_time": "2023-10-03T12:47:03.049595871Z"
}
},
"outputs": [],
"source": [
"# Define the training parameters for the fully connected classifier\n",
"batch_size = 32\n",
"epochs = 5\n",
"learning_rate = 1e-3\n",
"hidden_size = 128\n",
"num_classes = 10\n",
"train_frequency = epochs\n",
"test_frequency = epochs\n",
"# Create the fully connected classifier\n",
"input_size = trainset.data.shape[1] * trainset.data.shape[2] * trainset.data.shape[3]\n",
"classifier = DenseClassifier(input_size, hidden_size, num_classes).to(device)\n",
"input_size = trainset.data.shape[3]\n",
"cnn_classifier = CNNClassifier(input_size, hidden_size, num_classes).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "3104345cdee0eb00",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:47:03.095971390Z",
"start_time": "2023-10-03T12:47:03.080480468Z"
}
},
"outputs": [],
"source": [
"# Define the loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)\n",
"optimizer_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n"
]
},
{
"cell_type": "markdown",
"id": "e1fed39be2f04745",
"metadata": {
"collapsed": false
},
"source": [
"Train the non-linear classifiers"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "abc0c6ce338d40d9",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:49:41.536281525Z",
"start_time": "2023-10-03T12:47:03.094970008Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 80%|████████ | 4/5 [02:03<00:30, 30.65s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mlp====> Epoch: 5 Average loss: 0.0401\n",
"mlp====> Epoch: 5 Training accuracy: 55.44%\n",
"cnn====> Epoch: 5 Average loss: 0.0265\n",
"cnn====> Epoch: 5 Training accuracy: 69.87%\n",
"====> Test set loss: 0.0446\n",
"====> Test set accuracy: 0.5129\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5/5 [02:38<00:00, 31.69s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"====> Test set loss: 0.0319\n",
"====> Test set accuracy: 0.6478\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Train the classifier\n",
"for epoch in tqdm(range(1, epochs + 1)):\n",
" verbose = True if epoch % train_frequency == 0 else False\n",
" train_classifier(classifier, train_loader, optimizer, criterion, epoch, verbose)\n",
" train_classifier(cnn_classifier, train_loader, optimizer_cnn, criterion, epoch, verbose)\n",
"\n",
" # test every n epochs\n",
" if epoch % test_frequency == 0:\n",
" test_acc = test_classifier(classifier, test_loader, criterion)\n",
" test_acc_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n"
]
},
{
"cell_type": "markdown",
"id": "a06038f113d8434f",
"metadata": {
"collapsed": false
},
"source": [
"Load the encoder weights into the classifier"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "6a91d8894b70ef7c",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:49:41.571816419Z",
"start_time": "2023-10-03T12:49:41.535695938Z"
}
},
"outputs": [
{
"data": {
"text/plain": "<All keys matched successfully>"
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# initialize the classifier with the encoder weights\n",
"classifier.encoder.load_state_dict(ae.encoder.state_dict())\n",
"cnn_classifier.encoder.load_state_dict(cnn_ae.encoder.state_dict())"
]
},
{
"cell_type": "markdown",
"id": "aafa4a9ba7208647",
"metadata": {
"collapsed": false
},
"source": [
"Transfer learning"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "a60dd68f988a8249",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:52:09.051066459Z",
"start_time": "2023-10-03T12:49:41.543118970Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 80%|████████ | 4/5 [02:06<00:31, 31.63s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"mlp====> Epoch: 5 Average loss: 0.0539\n",
"mlp====> Epoch: 5 Training accuracy: 39.10%\n",
"cnn====> Epoch: 5 Average loss: 0.0333\n",
"cnn====> Epoch: 5 Training accuracy: 61.80%\n",
"====> Test set loss: 0.0536\n",
"====> Test set accuracy: 0.3961\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5/5 [02:27<00:00, 29.50s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"====> Test set loss: 0.0341\n",
"====> Test set accuracy: 0.6133\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# fine-tune the classifier\n",
"learning_rate = 1e-5\n",
"epoch = 20\n",
"optimizer_pretrained = optim.Adam(classifier.parameters(), lr=learning_rate)\n",
"optimizer_pretrained_cnn = optim.Adam(cnn_classifier.parameters(), lr=learning_rate)\n",
"for epoch in tqdm(range(1, epochs + 1)):\n",
" verbose = True if epoch % train_frequency == 0 else False\n",
" train_classifier(classifier, train_loader, optimizer_pretrained, criterion, epoch, verbose)\n",
" train_classifier(cnn_classifier, train_loader, optimizer_cnn, criterion, epoch, verbose)\n",
"\n",
" # test every n epochs\n",
" if epoch % test_frequency == 0:\n",
" test_acc_pretrained = test_classifier(classifier, test_loader, criterion)\n",
" test_acc_pretrained_cnn = test_classifier(cnn_classifier, test_loader, criterion)\n"
]
},
{
"cell_type": "markdown",
"id": "31577275b833707a",
"metadata": {
"collapsed": false
},
"source": [
"Compare the results of the linear probing with the results of the linear classifier"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "40d0e7f3f13404c9",
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:52:09.054596162Z",
"start_time": "2023-10-03T12:52:09.050719737Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model Linear Accuracy Non-linear accuracy Pretrained accuracy \n",
"MLP AE 0.3933 0.5129 0.3961 \n",
"CNN AE 0.3981 0.6478 0.6133 \n"
]
}
],
"source": [
"# print a table of the accuracies. compare the results with the results of the linear probing\n",
"print(f\"{'Model':<10} {'Linear Accuracy':<20} {'Non-linear accuracy':<20} {'Pretrained accuracy':<20}\")\n",
"print(f\"{'MLP AE':<10} {test_mlp[-1][1]:<20.4f} {test_acc:<20.4f} {test_acc_pretrained:<20.4f}\")\n",
"print(f\"{'CNN AE':<10} {test_cnn[-1][1]:<20.4f} {test_acc_cnn:<20.4f} {test_acc_pretrained_cnn:<20.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"outputs": [],
"source": [
"import torchvision.models as models\n",
"\n",
"class ResNet18Autoencoder(nn.Module):\n",
" def __init__(self):\n",
" super(ResNet18Autoencoder, self).__init__()\n",
" self.type = 'cnn'\n",
" # Encoder: Use pre-trained ResNet18 (without its final fc layer)\n",
" self.resnet18 = models.resnet18(pretrained=False)\n",
" self.encoder = nn.Sequential(*list(self.resnet18.children())[:-1], nn.Flatten())\n",
" \n",
" # Decoder: Create an up-sampling network\n",
" self.decoder = nn.Sequential(\n",
" nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1), # 8x8\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # 16x16\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # 32x32\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), # 64x64\n",
" nn.ReLU(),\n",
" nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1), # 128x128\n",
" nn.Sigmoid() # to ensure pixel values are between 0 and 1\n",
" )\n",
" \n",
" def forward(self, x):\n",
" x = self.encoder(x)\n",
" # unflatten the output of the encoder to be fed into the decoder\n",
" x = x.view(x.size(0), 512, 1, 1)\n",
" x = self.decoder(x)\n",
" return x\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:52:09.068053637Z",
"start_time": "2023-10-03T12:52:09.054382586Z"
}
},
"id": "e4037285d9d70694"
},
{
"cell_type": "code",
"execution_count": 23,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ResNet18 AE parameters: 14476811\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"# Define the training parameters for the ResNet18 Autoencoder\n",
"batch_size = 128\n",
"epochs = 10\n",
"learning_rate = 1e-3\n",
"train_frequency = epochs\n",
"test_frequency = epochs\n",
"\n",
"# Create the ResNet18 Autoencoder\n",
"resnet18_ae = ResNet18Autoencoder().to(device)\n",
"# print the model's number of parameters\n",
"print(f\"ResNet18 AE parameters: {sum(p.numel() for p in resnet18_ae.parameters())}\")\n",
"\n",
"# Define the loss function and optimizer\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(resnet18_ae.parameters(), lr=learning_rate)\n",
"scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, verbose=True)\n",
"# Create the train and test dataloaders\n",
"train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
"test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:52:09.214178804Z",
"start_time": "2023-10-03T12:52:09.061558426Z"
}
},
"id": "8963904ee52b9818"
},
{
"cell_type": "code",
"execution_count": 24,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 10 - Train Loss: 0.0013 - Test Loss: 0.0013: 90%|█████████ | 9/10 [01:49<00:11, 11.05s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"cnn====> Epoch: 10 Average loss: 0.0013\n",
"cnn Autoencoder\n",
"====> Test set loss: 0.0013\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train accuracy: 0.36802\n",
"Test accuracy: 0.3665\n",
"KNN accuracy: 0.3889\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 10 - Train Loss: 0.0013 - Test Loss: 0.0013: 100%|██████████| 10/10 [02:15<00:00, 13.52s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Clustering ARI score: 0.02700985553219709\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Train the model\n",
"test_resnet18 = []\n",
"loss = 0\n",
"pbar = tqdm(range(1, epochs + 1))\n",
"for epoch in pbar:\n",
" verbose = True if epoch % train_frequency == 0 else False\n",
" train_loss = train(resnet18_ae, train_loader, optimizer, criterion, epoch, verbose)\n",
" \n",
" # Update tqdm description with the training loss\n",
" pbar.set_description(f\"Epoch {epoch} - Train Loss: {train_loss:.4f} - Test Loss: {loss:.4f}\")\n",
" \n",
" if epoch % 2 == 0:\n",
" loss = test_loss(resnet18_ae, test_loader, criterion, verbose=False)\n",
" scheduler.step(loss)\n",
"\n",
" # test every n epochs\n",
" if epoch % test_frequency == 0:\n",
" results_resnet_dic = test(resnet18_ae, train_loader, test_loader, criterion)\n",
" test_resnet18.append([results_resnet_dic['reconstruction_loss'], results_resnet_dic['linear_classification_accuracy'], results_resnet_dic['knn_classification_accuracy'], results_resnet_dic['clustering_ari_score']])\n",
" \n",
"\n",
" "
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:54:24.395386156Z",
"start_time": "2023-10-03T12:52:09.200297772Z"
}
},
"id": "7af2592ad350ea8c"
},
{
"cell_type": "markdown",
"source": [
"Compare the evaluation results of the ResNet18 Autoencoder with the MLP and CNN Autoencoders "
],
"metadata": {
"collapsed": false
},
"id": "18cc2efab5a2e4b0"
},
{
"cell_type": "code",
"execution_count": 25,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model Reconstruction Loss Linear Accuracy KNN Accuracy Clustering ARI \n",
"MLP AE 0.0172 0.3933 0.3391 0.0397 \n",
"CNN AE 0.0044 0.3981 0.3545 0.0516 \n",
"ResNet18 AE 0.0013 0.3665 0.3889 0.0270 \n"
]
}
],
"source": [
"print(f\"{'Model':<15} {'Reconstruction Loss':<20} {'Linear Accuracy':<20} {'KNN Accuracy':<20} {'Clustering ARI':<20}\")\n",
"print(f\"{'MLP AE':<15} {test_mlp[-1][0]:<20.4f} {test_mlp[-1][1]:<20.4f} {test_mlp[-1][2]:<20.4f} {test_mlp[-1][3]:<20.4f}\")\n",
"print(f\"{'CNN AE':<15} {test_cnn[-1][0]:<20.4f} {test_cnn[-1][1]:<20.4f} {test_cnn[-1][2]:<20.4f} {test_cnn[-1][3]:<20.4f}\")\n",
"print(f\"{'ResNet18 AE':<15} {test_resnet18[-1][0]:<20.4f} {test_resnet18[-1][1]:<20.4f} {test_resnet18[-1][2]:<20.4f} {test_resnet18[-1][3]:<20.4f}\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:54:24.401416457Z",
"start_time": "2023-10-03T12:54:24.394398143Z"
}
},
"id": "a5dcaa09ceaf1f3f"
},
{
"cell_type": "code",
"execution_count": 26,
"outputs": [],
"source": [
"class ResNet18Classifier(nn.Module):\n",
" def __init__(self, num_classes=10, pretrained=False):\n",
" super(ResNet18Classifier, self).__init__()\n",
" self.type = 'cnn'\n",
" # Load the ResNet18 model\n",
" self.resnet18 = models.resnet18(pretrained=pretrained)\n",
" \n",
" # Adjust the first convolutional layer for CIFAR-10 image size\n",
" self.resnet18.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
" \n",
" # Adjust the final fully connected layer for CIFAR-10 number of classes\n",
" num_ftrs = self.resnet18.fc.in_features\n",
" self.resnet18.fc = nn.Linear(num_ftrs, num_classes)\n",
" \n",
" # Freeze the encoder weights except the final fc layer\n",
" for param in self.resnet18.parameters():\n",
" param.requires_grad = False\n",
" for param in self.resnet18.fc.parameters():\n",
" param.requires_grad = True\n",
" \n",
" def forward(self, x):\n",
" return self.resnet18(x)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:54:24.409917792Z",
"start_time": "2023-10-03T12:54:24.398893967Z"
}
},
"id": "f4e2e62d832761c0"
},
{
"cell_type": "code",
"execution_count": 27,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
" warnings.warn(msg)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ResNet18 classifier parameters: 11173962\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/fotis/PycharmProjects/representation_learning_tutorial/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"# Define the training parameters for the ResNet18 classifier\n",
"batch_size = 128\n",
"epochs = 50\n",
"learning_rate = 1e-4\n",
"train_frequency = epochs\n",
"test_frequency = epochs\n",
"\n",
"# Create the ResNet18 classifier\n",
"resnet18_classifier = ResNet18Classifier(num_classes=10, pretrained=False).to(device)\n",
"resnet18_classifier_pretrained = ResNet18Classifier(num_classes=10, pretrained=True).to(device)\n",
"# print the model's number of parameters\n",
"print(f\"ResNet18 classifier parameters: {sum(p.numel() for p in resnet18_classifier.parameters())}\")\n",
"\n",
"# Define the loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(resnet18_classifier.parameters(), lr=learning_rate)\n",
"optimizer_pretrained = optim.Adam(resnet18_classifier_pretrained.parameters(), lr=learning_rate)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T12:54:24.817741667Z",
"start_time": "2023-10-03T12:54:24.402570604Z"
}
},
"id": "1f4d0fbf693153da"
},
{
"cell_type": "code",
"execution_count": 28,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 49 - Train Loss: 0.0138 - Train Loss Pretrained: 0.0131: 98%|█████████▊| 49/50 [11:20<00:13, 13.33s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"cnn====> Epoch: 50 Average loss: 0.0138\n",
"cnn====> Epoch: 50 Training accuracy: 37.49%\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 50 - Train Loss: 0.0138 - Train Loss Pretrained: 0.0131: 98%|█████████▊| 49/50 [11:33<00:13, 13.33s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"cnn====> Epoch: 50 Average loss: 0.0131\n",
"cnn====> Epoch: 50 Training accuracy: 41.32%\n",
"====> Test set loss: 0.0142\n",
"====> Test set accuracy: 0.3560\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 50 - Train Loss: 0.0138 - Train Loss Pretrained: 0.0131: 100%|██████████| 50/50 [11:36<00:00, 13.93s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"====> Test set loss: 0.0136\n",
"====> Test set accuracy: 0.3913\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Train the model\n",
"test_resnet18_classifier = []\n",
"test_resnet18_classifier_pretrained = []\n",
"pbar = tqdm(range(1, epochs + 1))\n",
"\n",
"for epoch in pbar:\n",
" verbose = True if epoch % train_frequency == 0 else False\n",
" train_loss = train_classifier(resnet18_classifier, train_loader, optimizer, criterion, epoch, verbose)\n",
" train_loss_pretrained = train_classifier(resnet18_classifier_pretrained, train_loader, optimizer_pretrained, criterion, epoch, verbose)\n",
" \n",
" # Update tqdm description with the training loss\n",
" pbar.set_description(f\"Epoch {epoch} - Train Loss: {train_loss:.4f} - Train Loss Pretrained: {train_loss_pretrained:.4f}\")\n",
" \n",
" # test every n epochs\n",
" if epoch % test_frequency == 0:\n",
" results_resnet_acc = test_classifier(resnet18_classifier, test_loader, criterion)\n",
" test_resnet18_classifier.append(results_resnet_acc)\n",
" results_resnet_acc_pretrained = test_classifier(resnet18_classifier_pretrained, test_loader, criterion)\n",
" test_resnet18_classifier_pretrained.append(results_resnet_acc_pretrained)\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T13:06:01.403787907Z",
"start_time": "2023-10-03T12:54:24.817533512Z"
}
},
"id": "8452e383237a283"
},
{
"cell_type": "markdown",
"source": [
"Compare the accuracy results of the ResNet18 classifier with the MLP and CNN classifiers"
],
"metadata": {
"collapsed": false
},
"id": "9b3b561b83de5ccc"
},
{
"cell_type": "code",
"execution_count": 29,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model Accuracy \n",
"MLP AE 0.3933 \n",
"CNN AE 0.3981 \n",
"ResNet18 AE 0.3665 \n",
"ResNet18 Classifier 0.3560 \n",
"ResNet18 Classifier Pretrained 0.3913 \n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"print(f\"{'Model':<15} {'Accuracy':<20}\")\n",
"print(f\"{'MLP AE':<15} {test_mlp[-1][1]:<20.4f}\")\n",
"print(f\"{'CNN AE':<15} {test_cnn[-1][1]:<20.4f}\")\n",
"print(f\"{'ResNet18 AE':<15} {test_resnet18[-1][1]:<20.4f}\")\n",
"# take the average of the test accuracies\n",
"print(f\"{'ResNet18 Classifier':<15} {np.mean(test_resnet18_classifier):<20.4f}\")\n",
"print(f\"{'ResNet18 Classifier Pretrained':<15} {np.mean(test_resnet18_classifier_pretrained):<20.4f}\")\n"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2023-10-03T13:06:01.408109408Z",
"start_time": "2023-10-03T13:06:01.403113475Z"
}
},
"id": "8d1f038afa9bcff2"
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"collapsed": false
},
"id": "9ff6e7674c4a3e71"
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
},
"id": "36fdbf815e16107b"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}