diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/model.py | 14 |
| -rw-r--r-- | src/theta | 2 |
2 files changed, 8 insertions, 8 deletions
diff --git a/src/model.py b/src/model.py
index 8fb34a9..ee43050 100644
--- a/src/model.py
+++ b/src/model.py
@@ -1,6 +1,6 @@
 import numpy as np
 import matplotlib.pyplot as plt
-from sklearn.preprocessing import normalize as sklearn_normalize
+import sklearn.preprocessing
 
 
 class Model:
@@ -11,7 +11,7 @@ class Model:
         self.xs, self.ys = self._read_data()
 
     def train(self, alpha=1, epoch=100):
-        self.xs, self.ys = self._normalize_data()
+        self._normalize_data()
         for _ in range(epoch):
             next_theta0 = self.theta0 - alpha * self._partial_theta0()
             next_theta1 = self.theta1 - alpha * self._partial_theta1()
@@ -19,8 +19,8 @@ class Model:
             self.theta1 = next_theta1
 
     def write_theta(self):
-        with open(self.datafilename, "w") as file:
-            file.write("{},{}".format(str(theta1), str(theta0)))
+        with open(self.thetafilename, "w") as file:
+            file.write("{},{}".format(str(self.theta1), str(self.theta0)))
 
     def hypothesis(self, x):
         return x * self.theta1 + self.theta0
@@ -54,7 +54,7 @@ class Model:
                     for x, y in zip(self.xs, self.ys)]) / len(self.xs)
 
     def _normalize_data(self):
-        self.xs, self.ys = sklearn_normalize([self.xs, self.ys])
+        self.xs, self.ys = sklearn.preprocessing.normalize([self.xs, self.ys])
 
     def _read_theta(self):
         try:
@@ -62,13 +62,13 @@ class Model:
                 strs = file.read().strip().split(",")
                 if len(strs) != 2:
                     raise "wrong theta file format"
-                return int(strs[0]), int(strs[1])
+                return float(strs[0]), float(strs[1])
         except IOError:
             print(self.thetafilename, "do not exist")
 
     def _read_data(self):
         try:
-            data = np.genfromtxt(self.datafilename, delimiter=",")
+            data = np.genfromtxt(self.datafilename, delimiter=",")[1:]
             return data[:, 0], data[:, 1]
         except IOError:
             print(self.datafilename, "do not exist")
diff --git a/src/theta b/src/theta
--- a/src/theta
+++ b/src/theta
@@ -1 +1 @@
-0,0
+-0.19808642684030997,0.23574835208002437
\ No newline at end of file |
