aboutsummaryrefslogtreecommitdiff
path: root/src/model.py
diff options
context:
space:
mode:
authorCharles <sircharlesaze@gmail.com>2020-01-23 19:11:47 +0100
committerCharles <sircharlesaze@gmail.com>2020-01-23 19:11:47 +0100
commit8902768c980aec4d2c0ae63cecfb5de24bb573f6 (patch)
tree6989d0b2a10e3b032a4e88ac1ae2c210f75b8f2c /src/model.py
parentb21e6642c591962974212be2dbc965793df6bd06 (diff)
downloadft_linear_regression-8902768c980aec4d2c0ae63cecfb5de24bb573f6.tar.gz
ft_linear_regression-8902768c980aec4d2c0ae63cecfb5de24bb573f6.tar.bz2
ft_linear_regression-8902768c980aec4d2c0ae63cecfb5de24bb573f6.zip
Normalized x prediction to actualy predict something
Diffstat (limited to 'src/model.py')
-rw-r--r--src/model.py17
1 files changed, 7 insertions, 10 deletions
diff --git a/src/model.py b/src/model.py
index ee43050..0d74a69 100644
--- a/src/model.py
+++ b/src/model.py
@@ -1,6 +1,5 @@
import numpy as np
import matplotlib.pyplot as plt
-import sklearn.preprocessing
class Model:
@@ -25,6 +24,10 @@ class Model:
def hypothesis(self, x):
return x * self.theta1 + self.theta0
+ def make_prediction(self, predict):
+ predict = (predict - self.xs.min()) / (self.xs.max() - self.xs.min())
+ return self.hypothesis(predict)
+
def cost(self):
return (1 / (2 * len(self.xs))) * sum([(self.hypothesis(x) - y) ** 2
for x, y in zip(self.xs, self.ys)])
@@ -42,8 +45,8 @@ class Model:
def _plot_model(self):
line_xs = [self.xs.min(), self.xs.max()]
- line_ys = [self.hypothesis(x) for x in line_xs]
- self.ax.plot(line_xs, line_ys)
+ line_ys = [self.make_prediction(x) for x in line_xs]
+ self.ax.plot(line_xs, line_ys, color='r')
def _partial_theta1(self):
return sum([(self.hypothesis(x) - y) * x
@@ -54,7 +57,7 @@ class Model:
for x, y in zip(self.xs, self.ys)]) / len(self.xs)
def _normalize_data(self):
- self.xs, self.ys = sklearn.preprocessing.normalize([self.xs, self.ys])
+ self.xs = (self.xs - self.xs.min()) / (self.xs.max() - self.xs.min())
def _read_theta(self):
try:
@@ -72,9 +75,3 @@ class Model:
return data[:, 0], data[:, 1]
except IOError:
print(self.datafilename, "do not exist")
-
-
-if __name__ == "__main__":
- m = Model()
- m.train()
- m.write_theta()