diff options
| author | Charles <sircharlesaze@gmail.com> | 2020-01-23 09:57:39 +0100 |
|---|---|---|
| committer | Charles <sircharlesaze@gmail.com> | 2020-01-23 09:57:39 +0100 |
| commit | 4537bdbebb5fe64c50080e7874d407f10a0676b7 (patch) | |
| tree | 8466528ae9c3865e2543e7fdb9e9301bfef40490 /src/model.py | |
| parent | 1e45848f0b84218dfbdb62e313b4c33791a98555 (diff) | |
| download | ft_linear_regression-4537bdbebb5fe64c50080e7874d407f10a0676b7.tar.gz ft_linear_regression-4537bdbebb5fe64c50080e7874d407f10a0676b7.tar.bz2 ft_linear_regression-4537bdbebb5fe64c50080e7874d407f10a0676b7.zip | |
WIP: the CLI interface interacts with the Model class; the subprograms call Model methods to satisfy the subject
Diffstat (limited to 'src/model.py')
| -rw-r--r-- | src/model.py | 80 |
1 files changed, 80 insertions, 0 deletions
"""Univariate linear regression fitted by batch gradient descent (src/model.py)."""

import numpy as np


class Model:
    """Straight-line model ``y = theta1 * x + theta0``.

    The parameters are loaded from ``thetafilename`` (both default to 0.0 when
    the file is missing) and the training samples from ``datafilename``, a
    two-column comma-separated file (x in column 0, y in column 1).
    """

    def __init__(self, datafilename="../data.csv", thetafilename="./theta"):
        self.datafilename = datafilename
        self.thetafilename = thetafilename
        self.theta1, self.theta0 = self._read_theta()
        self.xs, self.ys = self._read_data()

    def train(self, alpha=1, epoch=100):
        """Fit theta0/theta1 with ``epoch`` steps of batch gradient descent.

        ``alpha`` is the learning rate.

        NOTE: the samples are rescaled in place (see ``_normalize_data``)
        before fitting, so the learned thetas, ``hypothesis`` and ``cost``
        all live in that rescaled space afterwards.

        Raises:
            ValueError: if no training sample is loaded.
        """
        if len(self.xs) == 0:
            raise ValueError("no training data loaded")
        self.xs, self.ys = self._normalize_data()
        for _ in range(epoch):
            # Compute both steps from the *current* thetas before assigning,
            # so the update is simultaneous.
            next_theta0 = self.theta0 - alpha * self._partial_theta0()
            next_theta1 = self.theta1 - alpha * self._partial_theta1()
            self.theta0 = next_theta0
            self.theta1 = next_theta1

    def write_theta(self):
        """Persist the parameters as ``theta1,theta0`` in ``thetafilename``."""
        # Must target the theta file: opening the data file in "w" mode
        # would wipe the training set.
        with open(self.thetafilename, "w") as file:
            # repr() of a float round-trips exactly through float().
            file.write("{!r},{!r}".format(float(self.theta1),
                                          float(self.theta0)))

    def hypothesis(self, x):
        """Return the model's prediction for ``x`` (scalar or NumPy array)."""
        return x * self.theta1 + self.theta0

    def cost(self):
        """Return the halved mean squared error over the loaded samples."""
        residuals = self._residuals()
        return (1 / (2 * len(residuals))) * float(np.sum(residuals ** 2))

    def plot(self, plot_data=True, plot_model=True):
        """Show the samples (scatter) and/or the fitted line in a window."""
        # Imported lazily so the model can be trained and queried on machines
        # without a plotting backend.
        import matplotlib.pyplot as plt

        self.fig, self.ax = plt.subplots()
        if plot_data:
            self._plot_data()
        if plot_model:
            self._plot_model()
        plt.show()

    def _plot_data(self):
        """Scatter the loaded samples on the current axes."""
        self.ax.scatter(self.xs, self.ys)

    def _plot_model(self):
        """Draw the fitted line between the smallest and largest x."""
        line_xs = [np.min(self.xs), np.max(self.xs)]
        line_ys = [self.hypothesis(x) for x in line_xs]
        self.ax.plot(line_xs, line_ys)

    def _residuals(self):
        """Return ``hypothesis(x) - y`` for every sample as a float array."""
        xs = np.asarray(self.xs, dtype=float)
        ys = np.asarray(self.ys, dtype=float)
        return self.hypothesis(xs) - ys

    def _partial_theta1(self):
        """Return the mean of ``residual * x`` (cost gradient w.r.t. theta1)."""
        xs = np.asarray(self.xs, dtype=float)
        return float(np.sum(self._residuals() * xs)) / len(xs)

    def _partial_theta0(self):
        """Return the mean residual (cost gradient w.r.t. theta0)."""
        residuals = self._residuals()
        return float(np.sum(residuals)) / len(residuals)

    @staticmethod
    def _unit_scale(values):
        """Return ``values`` divided by its L2 norm (unchanged if the norm is 0)."""
        values = np.asarray(values, dtype=float)
        norm = np.linalg.norm(values)
        return values / (norm or 1.0)

    def _normalize_data(self):
        """Rescale xs and ys to unit L2 norm, store them and return the pair."""
        self.xs = self._unit_scale(self.xs)
        self.ys = self._unit_scale(self.ys)
        return self.xs, self.ys

    def _read_theta(self):
        """Return ``(theta1, theta0)`` read from ``thetafilename``.

        A missing/unreadable file is reported on stdout and yields
        ``(0.0, 0.0)`` so that an untrained model is still usable.

        Raises:
            ValueError: if the file does not hold exactly two
                comma-separated numbers.
        """
        try:
            with open(self.thetafilename, "r") as file:
                strs = file.read().strip().split(",")
        except IOError:
            print(self.thetafilename, "do not exist")
            return 0.0, 0.0
        if len(strs) != 2:
            raise ValueError("wrong theta file format")
        return float(strs[0]), float(strs[1])

    def _read_data(self):
        """Return ``(xs, ys)`` float arrays read from ``datafilename``.

        A missing/unreadable file is reported on stdout and yields two empty
        arrays.

        Raises:
            ValueError: if the file has fewer than two columns.
        """
        try:
            data = np.genfromtxt(self.datafilename, delimiter=",")
        except IOError:
            print(self.datafilename, "do not exist")
            return np.empty(0), np.empty(0)
        # A single-row file comes back 1-D; force the (rows, columns) shape.
        data = np.atleast_2d(data)
        if data.shape[1] < 2:
            raise ValueError("wrong data file format")
        # genfromtxt turns fields it cannot parse as numbers (e.g. a header
        # line) into NaN; drop those rows so they do not poison the sums.
        data = data[~np.isnan(data[:, :2]).any(axis=1)]
        return data[:, 0], data[:, 1]


if __name__ == "__main__":
    m = Model()
    m.train()
    m.write_theta()
