diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/model.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/src/model.py b/src/model.py index f777b16..47a660c 100644 --- a/src/model.py +++ b/src/model.py @@ -6,6 +6,18 @@ class Model: for _ in range(epoch): theta = theta - alpha * self.gradient(xs, ys) + def train_against(self, xs, ys, theta, one, alpha, epoch): + ys_ally = ys.copy() + ys_ally[ys == one] = 0 + ys_ally[ys != one] = 1 + return gradient_descent(xs, ys_ally, theta, alpha, epoch) + + def train_thetas(xs, ys, theta, alpha=1, epoch=1000): + thetas = [] + for i in np.unique(ys): + thetas.append(train_against(xs, ys, theta, i, alpha, epoch)) + return thetas + def gradient(self, xs, ys): return np.array([self.partial(xs, ys, i) for i in range(len(self.theta))]) @@ -18,8 +30,11 @@ class Model: total += temp return total / len(xs) + def predict(self, x): + return 1 if self.hypothesis(x) >= 0.5 else 0 + def hypothesis(self, x): - return 1 if self._sigmoid(x.dot(self.theta)) >= 0.5 else 0 + return self._sigmoid(x.dot(self.theta)) def logloss(self, x, y): if y == 1: |
