aboutsummaryrefslogtreecommitdiff
path: root/src/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/model.py')
-rw-r--r--src/model.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/src/model.py b/src/model.py
index f777b16..47a660c 100644
--- a/src/model.py
+++ b/src/model.py
@@ -6,6 +6,18 @@ class Model:
for _ in range(epoch):
theta = theta - alpha * self.gradient(xs, ys)
+ def train_against(self, xs, ys, theta, one, alpha, epoch):
+ ys_ally = ys.copy()
+ ys_ally[ys == one] = 0
+ ys_ally[ys != one] = 1
+ return gradient_descent(xs, ys_ally, theta, alpha, epoch)
+
+ def train_thetas(xs, ys, theta, alpha=1, epoch=1000):
+ thetas = []
+ for i in np.unique(ys):
+ thetas.append(train_against(xs, ys, theta, i, alpha, epoch))
+ return thetas
+
def gradient(self, xs, ys):
return np.array([self.partial(xs, ys, i) for i in range(len(self.theta))])
@@ -18,8 +30,11 @@ class Model:
total += temp
return total / len(xs)
+ def predict(self, x):
+ return 1 if self.hypothesis(x) >= 0.5 else 0
+
def hypothesis(self, x):
- return 1 if self._sigmoid(x.dot(self.theta)) >= 0.5 else 0
+ return self._sigmoid(x.dot(self.theta))
def logloss(self, x, y):
if y == 1: