aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/model.py14
-rw-r--r--src/theta2
2 files changed, 8 insertions, 8 deletions
diff --git a/src/model.py b/src/model.py
index 8fb34a9..ee43050 100644
--- a/src/model.py
+++ b/src/model.py
@@ -1,6 +1,6 @@
import numpy as np
import matplotlib.pyplot as plt
-from sklearn.preprocessing import normalize as sklearn_normalize
+import sklearn.preprocessing
class Model:
@@ -11,7 +11,7 @@ class Model:
self.xs, self.ys = self._read_data()
def train(self, alpha=1, epoch=100):
- self.xs, self.ys = self._normalize_data()
+ self._normalize_data()
for _ in range(epoch):
next_theta0 = self.theta0 - alpha * self._partial_theta0()
next_theta1 = self.theta1 - alpha * self._partial_theta1()
@@ -19,8 +19,8 @@ class Model:
self.theta1 = next_theta1
def write_theta(self):
- with open(self.datafilename, "w") as file:
- file.write("{},{}".format(str(theta1), str(theta0)))
+ with open(self.thetafilename, "w") as file:
+ file.write("{},{}".format(str(self.theta1), str(self.theta0)))
def hypothesis(self, x):
return x * self.theta1 + self.theta0
@@ -54,7 +54,7 @@ class Model:
for x, y in zip(self.xs, self.ys)]) / len(self.xs)
def _normalize_data(self):
- self.xs, self.ys = sklearn_normalize([self.xs, self.ys])
+ self.xs, self.ys = sklearn.preprocessing.normalize([self.xs, self.ys])
def _read_theta(self):
try:
@@ -62,13 +62,13 @@ class Model:
strs = file.read().strip().split(",")
if len(strs) != 2:
raise "wrong theta file format"
- return int(strs[0]), int(strs[1])
+ return float(strs[0]), float(strs[1])
except IOError:
print(self.thetafilename, "do not exist")
def _read_data(self):
try:
- data = np.genfromtxt(self.datafilename, delimiter=",")
+ data = np.genfromtxt(self.datafilename, delimiter=",")[1:]
return data[:, 0], data[:, 1]
except IOError:
print(self.datafilename, "do not exist")
diff --git a/src/theta b/src/theta
index 15794e0..768218a 100644
--- a/src/theta
+++ b/src/theta
@@ -1 +1 @@
-0,0
+-0.19808642684030997,0.23574835208002437 \ No newline at end of file