update assignment 5
This commit is contained in:
parent
240b9d8695
commit
89a3d15e0e
@ -75,15 +75,28 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Define the perceptron algorithm\n",
|
"# Define the perceptron algorithm\n",
|
||||||
"class Perceptron:\n",
|
"class MultiClassPerceptron:\n",
|
||||||
" def __init__(self, learning_rate=0.01, n_iters=1000):\n",
|
" def __init__(self, input_dim, output_dim, lr=0.01, epochs=1000):\n",
|
||||||
|
" self.W = np.random.randn(input_dim, output_dim)\n",
|
||||||
|
" self.b = np.zeros((1, output_dim))\n",
|
||||||
|
" self.lr = lr\n",
|
||||||
|
" self.epochs = epochs\n",
|
||||||
|
"\n",
|
||||||
|
" def forward(self, X):\n",
|
||||||
|
" # ToDo: implement the forward() function\n",
|
||||||
" pass\n",
|
" pass\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # define the fit function to train the model\n",
|
" def backward(self, X, y):\n",
|
||||||
|
" # ToDo: implement the backward() function\n",
|
||||||
|
" pass\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # define the predict function to predict labels\n",
|
" def fit(self, X, y):\n",
|
||||||
|
" for epoch in range(self.epochs):\n",
|
||||||
|
" self.forward(X)\n",
|
||||||
|
" self.backward(X, y)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def _unit_step_func(self, x):\n",
|
" def predict(self, X):\n",
|
||||||
|
" # ToDo: implement the predict() function\n",
|
||||||
" pass"
|
" pass"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -105,7 +118,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Train the model\n",
|
"# Train the model\n",
|
||||||
"p = Perceptron(learning_rate=0.01, n_iters=1000)\n",
|
"p = MultiClassPerceptron(input_dim=X_train.shape[1], output_dim=3, lr=0.01, epochs=1000)\n",
|
||||||
"p.fit(X_train, y_train)\n",
|
"p.fit(X_train, y_train)\n",
|
||||||
"predictions_train = p.predict(X_train)\n",
|
"predictions_train = p.predict(X_train)\n",
|
||||||
"predictions = p.predict(X_test)"
|
"predictions = p.predict(X_test)"
|
||||||
|
Loading…
Reference in New Issue
Block a user