aboutsummaryrefslogtreecommitdiff
path: root/src/model.py
diff options
context:
space:
mode:
authorCharles <sircharlesaze@gmail.com>2020-01-23 14:39:15 +0100
committerCharles <sircharlesaze@gmail.com>2020-01-23 14:41:33 +0100
commitb21e6642c591962974212be2dbc965793df6bd06 (patch)
tree288785140417770e5e985efcaff3c48d6f9bdd64 /src/model.py
parent4537bdbebb5fe64c50080e7874d407f10a0676b7 (diff)
downloadft_linear_regression-b21e6642c591962974212be2dbc965793df6bd06.tar.gz
ft_linear_regression-b21e6642c591962974212be2dbc965793df6bd06.tar.bz2
ft_linear_regression-b21e6642c591962974212be2dbc965793df6bd06.zip
Fixing Model, still normalization problem
Diffstat (limited to 'src/model.py')
-rw-r--r--src/model.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/src/model.py b/src/model.py
index 8fb34a9..ee43050 100644
--- a/src/model.py
+++ b/src/model.py
@@ -1,6 +1,6 @@
import numpy as np
import matplotlib.pyplot as plt
-from sklearn.preprocessing import normalize as sklearn_normalize
+import sklearn.preprocessing
class Model:
@@ -11,7 +11,7 @@ class Model:
self.xs, self.ys = self._read_data()
def train(self, alpha=1, epoch=100):
- self.xs, self.ys = self._normalize_data()
+ self._normalize_data()
for _ in range(epoch):
next_theta0 = self.theta0 - alpha * self._partial_theta0()
next_theta1 = self.theta1 - alpha * self._partial_theta1()
@@ -19,8 +19,8 @@ class Model:
self.theta1 = next_theta1
def write_theta(self):
- with open(self.datafilename, "w") as file:
- file.write("{},{}".format(str(theta1), str(theta0)))
+ with open(self.thetafilename, "w") as file:
+ file.write("{},{}".format(str(self.theta1), str(self.theta0)))
def hypothesis(self, x):
return x * self.theta1 + self.theta0
@@ -54,7 +54,7 @@ class Model:
for x, y in zip(self.xs, self.ys)]) / len(self.xs)
def _normalize_data(self):
- self.xs, self.ys = sklearn_normalize([self.xs, self.ys])
+ self.xs, self.ys = sklearn.preprocessing.normalize([self.xs, self.ys])
def _read_theta(self):
try:
@@ -62,13 +62,13 @@ class Model:
strs = file.read().strip().split(",")
if len(strs) != 2:
raise "wrong theta file format"
- return int(strs[0]), int(strs[1])
+ return float(strs[0]), float(strs[1])
except IOError:
print(self.thetafilename, "do not exist")
def _read_data(self):
try:
- data = np.genfromtxt(self.datafilename, delimiter=",")
+ data = np.genfromtxt(self.datafilename, delimiter=",")[1:]
return data[:, 0], data[:, 1]
except IOError:
print(self.datafilename, "do not exist")