diff options
| author | Charles <sircharlesaze@gmail.com> | 2020-01-23 14:39:15 +0100 |
|---|---|---|
| committer | Charles <sircharlesaze@gmail.com> | 2020-01-23 14:41:33 +0100 |
| commit | b21e6642c591962974212be2dbc965793df6bd06 (patch) | |
| tree | 288785140417770e5e985efcaff3c48d6f9bdd64 /src/model.py | |
| parent | 4537bdbebb5fe64c50080e7874d407f10a0676b7 (diff) | |
| download | ft_linear_regression-b21e6642c591962974212be2dbc965793df6bd06.tar.gz ft_linear_regression-b21e6642c591962974212be2dbc965793df6bd06.tar.bz2 ft_linear_regression-b21e6642c591962974212be2dbc965793df6bd06.zip | |
Fixing Model, still normalization problem
Diffstat (limited to 'src/model.py')
| -rw-r--r-- | src/model.py | 14 |
1 file changed, 7 insertions, 7 deletions
```diff
diff --git a/src/model.py b/src/model.py
index 8fb34a9..ee43050 100644
--- a/src/model.py
+++ b/src/model.py
@@ -1,6 +1,6 @@
 import numpy as np
 import matplotlib.pyplot as plt
-from sklearn.preprocessing import normalize as sklearn_normalize
+import sklearn.preprocessing
 
 
 class Model:
@@ -11,7 +11,7 @@ class Model:
         self.xs, self.ys = self._read_data()
 
     def train(self, alpha=1, epoch=100):
-        self.xs, self.ys = self._normalize_data()
+        self._normalize_data()
         for _ in range(epoch):
             next_theta0 = self.theta0 - alpha * self._partial_theta0()
             next_theta1 = self.theta1 - alpha * self._partial_theta1()
@@ -19,8 +19,8 @@ class Model:
             self.theta1 = next_theta1
 
     def write_theta(self):
-        with open(self.datafilename, "w") as file:
-            file.write("{},{}".format(str(theta1), str(theta0)))
+        with open(self.thetafilename, "w") as file:
+            file.write("{},{}".format(str(self.theta1), str(self.theta0)))
 
     def hypothesis(self, x):
         return x * self.theta1 + self.theta0
@@ -54,7 +54,7 @@ class Model:
                     for x, y in zip(self.xs, self.ys)]) / len(self.xs)
 
     def _normalize_data(self):
-        self.xs, self.ys = sklearn_normalize([self.xs, self.ys])
+        self.xs, self.ys = sklearn.preprocessing.normalize([self.xs, self.ys])
 
     def _read_theta(self):
         try:
@@ -62,13 +62,13 @@ class Model:
                 strs = file.read().strip().split(",")
                 if len(strs) != 2:
                     raise "wrong theta file format"
-                return int(strs[0]), int(strs[1])
+                return float(strs[0]), float(strs[1])
         except IOError:
             print(self.thetafilename, "do not exist")
 
     def _read_data(self):
         try:
-            data = np.genfromtxt(self.datafilename, delimiter=",")
+            data = np.genfromtxt(self.datafilename, delimiter=",")[1:]
             return data[:, 0], data[:, 1]
         except IOError:
             print(self.datafilename, "do not exist")
```
