{
"cells": [
{
"cell_type": "markdown",
"source": [
"### Solution for Assignment 3 of the course \"Introduction to Machine Learning\" at the University of Leoben.\n",
"##### Author: Fotios Lygerakis\n",
"##### Semester: SS 2022/2023"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"Create an abstract class “ContinuousDistribution”. The class must contain the following function definitions (not implementations):\n",
"* Data import and export using CSV files.\n",
"* Computation of the mean based on the samples from the CSV file.\n",
"* Computation of the standard deviation based on the samples from the CSV file.\n",
"* Visualization of the distribution, the raw data, or the generated samples.\n",
"* Generating/drawing samples from the distribution.\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import abc\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from scipy.stats import multivariate_normal\n",
"from scipy.stats import beta"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class ContinuousDistribution(abc.ABC):\n",
"    def __init__(self, name):\n",
"        self.name = name\n",
"\n",
"    @abc.abstractmethod\n",
"    def import_data(self, filename):\n",
"        pass\n",
"\n",
"    @abc.abstractmethod\n",
"    def export_data(self, filename):\n",
"        pass\n",
"\n",
"    @abc.abstractmethod\n",
"    def compute_mean(self):\n",
"        pass\n",
"\n",
"    @abc.abstractmethod\n",
"    def compute_std(self):\n",
"        pass\n",
"\n",
"    @abc.abstractmethod\n",
"    def visualize(self):\n",
"        pass\n",
"\n",
"    @abc.abstractmethod\n",
"    def generate_samples(self, n):\n",
"        pass"
]
},
{
"cell_type": "markdown",
"source": [
"Implement a class “GaussDistribution”, which implements a multivariate Gaussian distribution (Equation 2.6 in the book).\n",
"* It is a child class of “ContinuousDistribution”.\n",
"* Implement the functions defined in “ContinuousDistribution”.\n",
"* Implement a constructor that optionally takes the dimension of the multivariate distribution.\n",
"* Implement a visualization for multivariate Gaussians of up to 3 dimensions.\n",
"* Find the empirical parameters of the distribution that created the samples in the MGD.csv file.\n",
"* Plot the samples of the MGD.csv file and the samples from the learned distribution in two subfigures.\n",
"* Generate visualizations of one- and two-dimensional Gaussians.\n",
"\n",
"**The actual mean and covariance matrix of the MGD.csv file are [1, 0, 1] and [[1, 0.5, 0.5], [0.5, 1, 0.5], [0.5, 0.5, 1]] respectively.**\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"class GaussDistribution(ContinuousDistribution):\n",
"    def __init__(self, name, dim=1, mean=None, cov=None):\n",
"        super().__init__(name)\n",
"        self.dim = dim\n",
"        self.mean = mean\n",
"        # variance in 1D, covariance matrix otherwise\n",
"        self.cov = cov\n",
"        self.data = None\n",
"        self.sampled_data = None\n",
"\n",
"    def import_data(self, filename):\n",
"        self.data = pd.read_csv(filename)\n",
"\n",
"    def export_data(self, filename):\n",
"        self.data.to_csv(filename, index=False)\n",
"\n",
"    def compute_mean(self):\n",
"        return self.data.mean()\n",
"\n",
"    def compute_std(self):\n",
"        # return the covariance matrix if the distribution is multivariate\n",
"        if self.dim > 1:\n",
"            return self.data.cov()\n",
"        return self.data.std()\n",
"\n",
"    def visualize(self):\n",
"        if self.dim == 1:\n",
"            # 1D: histogram of the samples with the pdf drawn on top\n",
"            data = self.sampled_data if self.sampled_data is not None else self.data\n",
"            # plot the samples\n",
"            plt.hist(data['x'], density=True, histtype='stepfilled', alpha=0.2)\n",
"\n",
"            # create a grid of points\n",
"            x = np.linspace(data['x'].min(), data['x'].max(), 100)\n",
"\n",
"            # plot the pdf of the distribution\n",
"            plt.plot(x, multivariate_normal.pdf(x, self.mean, self.cov))\n",
"\n",
"            plt.show()\n",
"        elif self.dim == 2:\n",
"            # 2D: filled contour plot of the pdf with the samples scattered on top\n",
"            data = self.sampled_data if self.sampled_data is not None else self.data\n",
"            # create a grid of points\n",
"            x = np.linspace(data['x'].min(), data['x'].max(), 100)\n",
"            y = np.linspace(data['y'].min(), data['y'].max(), 100)\n",
"\n",
"            # plot the pdf of the distribution\n",
"            X, Y = np.meshgrid(x, y)\n",
"            pos = np.empty(X.shape + (2,))\n",
"            pos[:, :, 0] = X\n",
"            pos[:, :, 1] = Y\n",
"            rv = multivariate_normal(self.mean, self.cov)\n",
"            plt.contourf(X, Y, rv.pdf(pos))\n",
"            # plot the samples\n",
"            plt.scatter(data['x'], data['y'], c='r', marker='o')\n",
"            plt.show()\n",
"        elif self.dim == 3:\n",
"            # 3D: the imported data and the samples drawn from the learned\n",
"            # distribution in two separate titled sub-figures\n",
"            fig = plt.figure()\n",
"            ax = fig.add_subplot(121, projection='3d')\n",
"            ax.set_title('Original Data')\n",
"            ax.scatter(self.data['x'], self.data['y'], self.data['z'], c='r', marker='o')\n",
"            ax.set_xlabel('X Label')\n",
"            ax.set_ylabel('Y Label')\n",
"            ax.set_zlabel('Z Label')\n",
"            ax = fig.add_subplot(122, projection='3d')\n",
"            ax.set_title('Sampled Data')\n",
"            ax.scatter(self.sampled_data['x'], self.sampled_data['y'], self.sampled_data['z'], c='b', marker='o')\n",
"            ax.set_xlabel('X Label')\n",
"            ax.set_ylabel('Y Label')\n",
"            ax.set_zlabel('Z Label')\n",
"            plt.show()\n",
"\n",
"    def generate_samples(self, n):\n",
"        # generate samples from the distribution\n",
"        if self.dim == 1:\n",
"            # np.random.normal expects the standard deviation, self.cov holds the variance\n",
"            self.sampled_data = pd.DataFrame(np.random.normal(self.mean, np.sqrt(self.cov), n), columns=['x'])\n",
"        elif self.dim == 2:\n",
"            self.sampled_data = pd.DataFrame(np.random.multivariate_normal(self.mean, self.cov, n), columns=['x', 'y'])\n",
"        elif self.dim == 3:\n",
"            self.sampled_data = pd.DataFrame(np.random.multivariate_normal(self.mean, self.cov, n), columns=['x', 'y', 'z'])\n",
"\n",
"    def fit(self):\n",
"        self.mean = self.compute_mean()\n",
"        spread = self.compute_std()\n",
"        # compute_std returns a standard deviation in 1D; square it so that\n",
"        # self.cov always holds a (co)variance\n",
"        self.cov = spread ** 2 if self.dim == 1 else spread"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"Find the empirical parameters of the distribution that created the samples in the MGD.csv file."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "<Figure size 640x480 with 2 Axes>"
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"x    0.984026\n",
"y    0.008833\n",
"z    0.860115\n",
"dtype: float64\n",
"          x         y         z\n",
"x  0.957077  0.652484  0.563223\n",
"y  0.652484  1.420160  0.753830\n",
"z  0.563223  0.753830  1.255599\n"
]
}
],
"source": [
"gd = GaussDistribution('MGD', 3)\n",
"gd.import_data('MGD.csv')\n",
"gd.fit()\n",
"gd.generate_samples(100)\n",
"gd.visualize()\n",
"# print the mean and covariance matrix of the distribution\n",
"print(gd.mean)\n",
"print(gd.cov)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"Generate visualizations of one- and two-dimensional Gaussians using the above class."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "<Figure size 640x480 with 1 Axes>"
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "<Figure size 640x480 with 1 Axes>"
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"gd = GaussDistribution('1D', dim=1, mean=0, cov=1)\n",
"gd.generate_samples(100)\n",
"gd.visualize()\n",
"\n",
"gd = GaussDistribution('2D', dim=2, mean=[0, 0], cov=[[1, 0], [0, 1]])\n",
"gd.generate_samples(100)\n",
"gd.visualize()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"Implement a class “BetaDistribution” (Equation 2.5 in the book)\n",
"* It is also a child class of “ContinuousDistribution”.\n",
"* Its purpose is to generate beta-distributed samples and plot the distribution given the parameters a and b.\n",
"* The constructor should take the parameters a and b as arguments.\n",
"* Implement a visualization for Beta distributions, including the mean and the standard deviation lines."
],
"metadata": {
"collapsed": false
}
},
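{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [
"class BetaDistribution(ContinuousDistribution):\n",
"    def __init__(self, name, a, b):\n",
"        super().__init__(name)\n",
"        self.a = a\n",
"        self.b = b\n",
"        self.data = None\n",
"        self.sampled_data = None\n",
"\n",
"    def import_data(self, filename):\n",
"        pass\n",
"\n",
"    def export_data(self, filename):\n",
"        self.sampled_data.to_csv(filename, index=False)\n",
"\n",
"    def compute_mean(self):\n",
"        # analytic mean of Beta(a, b), computed from the parameters rather than from samples\n",
"        return self.a / (self.a + self.b)\n",
"\n",
"    def compute_std(self):\n",
"        # analytic standard deviation of Beta(a, b)\n",
"        total = self.a + self.b\n",
"        return np.sqrt(self.a * self.b / (total ** 2 * (total + 1)))\n",
"\n",
"    def visualize(self):\n",
"        # plot the pdf of the beta distribution\n",
"        x = np.linspace(0, 1, 100)\n",
"        y = beta.pdf(x, self.a, self.b)\n",
"        plt.plot(x, y)\n",
"        # mark the mean and one standard deviation on either side of it\n",
"        mean, std = self.compute_mean(), self.compute_std()\n",
"        plt.axvline(mean, color='k', linestyle='--', label='mean')\n",
"        plt.axvline(mean - std, color='gray', linestyle=':', label='mean +/- std')\n",
"        plt.axvline(mean + std, color='gray', linestyle=':')\n",
"        plt.legend()\n",
"        plt.title('Beta Distribution with a = ' + str(self.a) + ' and b = ' + str(self.b))\n",
"        plt.xlabel('x')\n",
"        plt.ylabel('y')\n",
"        plt.show()\n",
"\n",
"    def generate_samples(self, n):\n",
"        self.sampled_data = pd.DataFrame(np.random.beta(self.a, self.b, n), columns=['x'])"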
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"Generate visualizations of Beta distributions using the above class."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "<Figure size 640x480 with 1 Axes>"
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"bd = BetaDistribution('Beta', 2, 5)\n",
"bd.generate_samples(100)\n",
"bd.visualize()"
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}