update assignment 4 solution

This commit is contained in:
ligerfotis 2023-05-16 15:02:56 +02:00
parent 41741ca8f4
commit 1b75b09719

View File

@ -13,7 +13,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 2,
"outputs": [], "outputs": [],
"source": [ "source": [
"import pandas as pd\n", "import pandas as pd\n",
@ -85,7 +85,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 3,
"metadata": { "metadata": {
"collapsed": true "collapsed": true
}, },
@ -141,14 +141,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 4,
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": " age sex bmi bp s1 s2 s3 \\\n0 0.794887 1.061173 1.357096 0.459459 -0.917834 -0.734476 -0.958901 \n1 -0.038221 -0.940162 -1.095193 -0.557425 -0.148672 -0.395182 1.714481 \n2 1.779468 1.061173 0.983414 -0.121617 -0.947417 -0.720904 -0.708271 \n3 -1.855910 -0.940162 -0.231053 -0.775328 0.295075 0.561626 -0.791815 \n4 0.113253 -0.940162 -0.768221 0.459459 0.117576 0.358049 0.210704 \n\n s4 s5 s6 target \n0 -0.035628 0.434041 -0.356981 151.0 \n1 -0.856638 -1.429397 -1.923328 75.0 \n2 -0.035628 0.074059 -0.531020 141.0 \n3 0.785382 0.492755 -0.182943 206.0 \n4 -0.035628 -0.661884 -0.966116 135.0 ", "text/plain": " age sex bmi bp s1 s2 s3 \\\n0 0.794887 1.061173 1.357096 0.459459 -0.917834 -0.734476 -0.958901 \n1 -0.038221 -0.940162 -1.095193 -0.557425 -0.148672 -0.395182 1.714481 \n2 1.779468 1.061173 0.983414 -0.121617 -0.947417 -0.720904 -0.708271 \n3 -1.855910 -0.940162 -0.231053 -0.775328 0.295075 0.561626 -0.791815 \n4 0.113253 -0.940162 -0.768221 0.459459 0.117576 0.358049 0.210704 \n\n s4 s5 s6 target \n0 -0.035628 0.434041 -0.356981 151.0 \n1 -0.856638 -1.429397 -1.923328 75.0 \n2 -0.035628 0.074059 -0.531020 141.0 \n3 0.785382 0.492755 -0.182943 206.0 \n4 -0.035628 -0.661884 -0.966116 135.0 ",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>age</th>\n <th>sex</th>\n <th>bmi</th>\n <th>bp</th>\n <th>s1</th>\n <th>s2</th>\n <th>s3</th>\n <th>s4</th>\n <th>s5</th>\n <th>s6</th>\n <th>target</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>0.794887</td>\n <td>1.061173</td>\n <td>1.357096</td>\n <td>0.459459</td>\n <td>-0.917834</td>\n <td>-0.734476</td>\n <td>-0.958901</td>\n <td>-0.035628</td>\n <td>0.434041</td>\n <td>-0.356981</td>\n <td>151.0</td>\n </tr>\n <tr>\n <th>1</th>\n <td>-0.038221</td>\n <td>-0.940162</td>\n <td>-1.095193</td>\n <td>-0.557425</td>\n <td>-0.148672</td>\n <td>-0.395182</td>\n <td>1.714481</td>\n <td>-0.856638</td>\n <td>-1.429397</td>\n <td>-1.923328</td>\n <td>75.0</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1.779468</td>\n <td>1.061173</td>\n <td>0.983414</td>\n <td>-0.121617</td>\n <td>-0.947417</td>\n <td>-0.720904</td>\n <td>-0.708271</td>\n <td>-0.035628</td>\n <td>0.074059</td>\n <td>-0.531020</td>\n <td>141.0</td>\n </tr>\n <tr>\n <th>3</th>\n <td>-1.855910</td>\n <td>-0.940162</td>\n <td>-0.231053</td>\n <td>-0.775328</td>\n <td>0.295075</td>\n <td>0.561626</td>\n <td>-0.791815</td>\n <td>0.785382</td>\n <td>0.492755</td>\n <td>-0.182943</td>\n <td>206.0</td>\n </tr>\n <tr>\n <th>4</th>\n <td>0.113253</td>\n <td>-0.940162</td>\n <td>-0.768221</td>\n <td>0.459459</td>\n <td>0.117576</td>\n <td>0.358049</td>\n <td>0.210704</td>\n <td>-0.035628</td>\n <td>-0.661884</td>\n <td>-0.966116</td>\n <td>135.0</td>\n </tr>\n </tbody>\n</table>\n</div>" "text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>age</th>\n <th>sex</th>\n <th>bmi</th>\n <th>bp</th>\n <th>s1</th>\n <th>s2</th>\n <th>s3</th>\n <th>s4</th>\n <th>s5</th>\n <th>s6</th>\n <th>target</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>0.794887</td>\n <td>1.061173</td>\n <td>1.357096</td>\n <td>0.459459</td>\n <td>-0.917834</td>\n <td>-0.734476</td>\n <td>-0.958901</td>\n <td>-0.035628</td>\n <td>0.434041</td>\n <td>-0.356981</td>\n <td>151.0</td>\n </tr>\n <tr>\n <th>1</th>\n <td>-0.038221</td>\n <td>-0.940162</td>\n <td>-1.095193</td>\n <td>-0.557425</td>\n <td>-0.148672</td>\n <td>-0.395182</td>\n <td>1.714481</td>\n <td>-0.856638</td>\n <td>-1.429397</td>\n <td>-1.923328</td>\n <td>75.0</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1.779468</td>\n <td>1.061173</td>\n <td>0.983414</td>\n <td>-0.121617</td>\n <td>-0.947417</td>\n <td>-0.720904</td>\n <td>-0.708271</td>\n <td>-0.035628</td>\n <td>0.074059</td>\n <td>-0.531020</td>\n <td>141.0</td>\n </tr>\n <tr>\n <th>3</th>\n <td>-1.855910</td>\n <td>-0.940162</td>\n <td>-0.231053</td>\n <td>-0.775328</td>\n <td>0.295075</td>\n <td>0.561626</td>\n <td>-0.791815</td>\n <td>0.785382</td>\n <td>0.492755</td>\n <td>-0.182943</td>\n <td>206.0</td>\n </tr>\n <tr>\n <th>4</th>\n <td>0.113253</td>\n <td>-0.940162</td>\n <td>-0.768221</td>\n <td>0.459459</td>\n <td>0.117576</td>\n <td>0.358049</td>\n <td>0.210704</td>\n <td>-0.035628</td>\n <td>-0.661884</td>\n <td>-0.966116</td>\n <td>135.0</td>\n </tr>\n </tbody>\n</table>\n</div>"
}, },
"execution_count": 3, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -183,16 +183,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 10,
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"X_train: (353, 10)\n", "X_train: (344, 10)\n",
"X_test: (89, 10)\n", "X_test: (86, 10)\n",
"y_train: (353,)\n", "y_train: (344,)\n",
"y_test: (89,)\n" "y_test: (86,)\n"
] ]
} }
], ],
@ -201,10 +201,7 @@
"\n", "\n",
"normalize = True\n", "normalize = True\n",
"# Load the data\n", "# Load the data\n",
"# X_train, X_test, y_train, y_test, df = load_data(normalize=normalize)\n", "X_train, X_test, y_train, y_test, df = load_data(normalize=normalize)\n",
"data = load_diabetes(scaled=True)\n",
"# split data into train and test sets\n",
"X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2)\n",
"print(\"X_train:\", X_train.shape)\n", "print(\"X_train:\", X_train.shape)\n",
"print(\"X_test:\", X_test.shape)\n", "print(\"X_test:\", X_test.shape)\n",
"print(\"y_train:\", y_train.shape)\n", "print(\"y_train:\", y_train.shape)\n",
@ -225,7 +222,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 11,
"outputs": [], "outputs": [],
"source": [ "source": [
"# Fit the linear regression\n", "# Fit the linear regression\n",
@ -238,7 +235,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 12,
"outputs": [], "outputs": [],
"source": [ "source": [
"# Fit the ridge regression\n", "# Fit the ridge regression\n",
@ -251,7 +248,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 13,
"outputs": [], "outputs": [],
"source": [ "source": [
"# Fit the lasso regression\n", "# Fit the lasso regression\n",
@ -273,54 +270,54 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 14,
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Linear Regression MSE train: 2886.03 test: 2801.42\n", "Linear Regression MSE train: 2983.19 test: 2487.62\n",
"Ridge Regression MSE train: 3417.29 test: 3129.15\n", "Ridge Regression MSE train: 2983.79 test: 2495.37\n",
"Lasso Regression MSE train: 2905.98 test: 2812.49\n", "Lasso Regression MSE train: 2983.19 test: 2487.98\n",
"Linear Regression RMSE train: 53.72 test: 52.93\n", "Linear Regression RMSE train: 54.62 test: 49.88\n",
"Ridge Regression RMSE train: 58.46 test: 55.94\n", "Ridge Regression RMSE train: 54.62 test: 49.95\n",
"Lasso Regression RMSE train: 53.91 test: 53.03\n", "Lasso Regression RMSE train: 54.62 test: 49.88\n",
"Linear Regression R2 train: 0.51 test: 0.54\n", "Linear Regression R2 train: 0.50 test: 0.49\n",
"Ridge Regression R2 train: 0.42 test: 0.49\n", "Ridge Regression R2 train: 0.50 test: 0.49\n",
"Lasso Regression R2 train: 0.51 test: 0.54\n", "Lasso Regression R2 train: 0.50 test: 0.49\n",
"Linear Regression features sorted by their coefficients:\n", "Linear Regression features sorted by their coefficients:\n",
"s1: -822.75\n", "s5: 28.62\n",
"s5: 765.40\n", "bmi: 23.95\n",
"bmi: 514.60\n", "s1: -23.18\n",
"s2: 424.92\n", "bp: 17.64\n",
"bp: 355.26\n", "s2: 14.38\n",
"sex: -241.22\n", "sex: -12.17\n",
"s4: 230.86\n", "s4: 4.23\n",
"s3: 129.59\n", "s3: -4.03\n",
"s6: 40.86\n", "s6: 2.06\n",
"age: -10.93\n", "age: 0.86\n",
"Ridge Regression features sorted by their coefficients:\n", "Ridge Regression features sorted by their coefficients:\n",
"bmi: 283.92\n", "s5: 26.16\n",
"s5: 238.87\n", "bmi: 23.91\n",
"bp: 195.46\n", "bp: 17.60\n",
"s3: -142.91\n", "s1: -16.67\n",
"s4: 106.92\n", "sex: -12.13\n",
"s6: 93.55\n", "s2: 9.11\n",
"sex: -63.98\n", "s3: -6.64\n",
"s2: -30.82\n", "s4: 3.73\n",
"age: 24.08\n", "s6: 2.12\n",
"s1: 0.97\n", "age: 0.87\n",
"Lasso Regression features sorted by their coefficients:\n", "Lasso Regression features sorted by their coefficients:\n",
"bmi: 528.98\n", "s5: 28.48\n",
"s5: 494.25\n", "bmi: 23.95\n",
"bp: 348.14\n", "s1: -22.80\n",
"sex: -230.35\n", "bp: 17.64\n",
"s3: -197.76\n", "s2: 14.07\n",
"s2: -137.48\n", "sex: -12.17\n",
"s4: 127.48\n", "s4: 4.19\n",
"s1: -101.61\n", "s3: -4.18\n",
"s6: 40.95\n", "s6: 2.06\n",
"age: -7.39\n", "age: 0.86\n",
"Linear Regression number of non-zero coefficients: 11\n", "Linear Regression number of non-zero coefficients: 11\n",
"Ridge Regression number of non-zero coefficients: 11\n", "Ridge Regression number of non-zero coefficients: 11\n",
"Lasso Regression number of non-zero coefficients: 11\n" "Lasso Regression number of non-zero coefficients: 11\n"