diff options
| author | Charles <sircharlesaze@gmail.com> | 2020-01-23 09:57:39 +0100 |
|---|---|---|
| committer | Charles <sircharlesaze@gmail.com> | 2020-01-23 09:57:39 +0100 |
| commit | 4537bdbebb5fe64c50080e7874d407f10a0676b7 (patch) | |
| tree | 8466528ae9c3865e2543e7fdb9e9301bfef40490 /src | |
| parent | 1e45848f0b84218dfbdb62e313b4c33791a98555 (diff) | |
| download | ft_linear_regression-4537bdbebb5fe64c50080e7874d407f10a0676b7.tar.gz ft_linear_regression-4537bdbebb5fe64c50080e7874d407f10a0676b7.tar.bz2 ft_linear_regression-4537bdbebb5fe64c50080e7874d407f10a0676b7.zip | |
WIP: CLI interface interact with Model class, subprogram call Model methods to satisfy the subject
Diffstat (limited to 'src')
| -rw-r--r-- | src/cli.py | 63 | ||||
| -rw-r--r-- | src/cost.py | 7 | ||||
| -rw-r--r-- | src/model.py | 80 | ||||
| -rw-r--r-- | src/predict.py | 29 | ||||
| -rw-r--r-- | src/train.py | 30 |
5 files changed, 163 insertions, 46 deletions
```diff
diff --git a/src/cli.py b/src/cli.py
new file mode 100644
index 0000000..ce759be
--- /dev/null
+++ b/src/cli.py
@@ -0,0 +1,63 @@
+import sys
+import argparse
+
+from model import Model
+import predict
+
+class CommandLineInterface:
+    def __init__(self):
+        self.model = Model()
+
+    def parse_args(self):
+        parser = argparse.ArgumentParser(
+            prog="ft_linear_regression_cli",
+            description="CLI to interact with the ft_linear_regression project"
+        )
+        subparsers = parser.add_subparsers(help="sub-command help", dest="subparser_name")
+
+        parser_train = subparsers.add_parser("train", help="train the model")
+        parser_train.set_defaults(func=self._train)
+        parser_train.add_argument("-a --alpha", type=float, default=1.0, dest="alpha", help="learning rate")
+        parser_train.add_argument("-e --epoch", type=int, default=100, dest="epoch", help="number of iterations")
+
+        parser_predict = subparsers.add_parser("predict", help="make a predict")
+        parser_predict.set_defaults(func=self._predict)
+        parser_predict.add_argument("-x", type=int, help="mileage for which the prediction will be made")
+
+        parser_cost = subparsers.add_parser("cost", help="print model cost")
+        parser_cost.set_defaults(func=self._cost)
+
+        parser_plot = subparsers.add_parser("plot", help="plot data and model")
+        parser_plot.set_defaults(func=self._plot)
+        parser_plot.add_argument("-d --data", help="only plot data", action="store_true", dest="plot_data")
+        parser_plot.add_argument("-m --model", help="only plot model", action="store_true", dest="plot_model")
+
+        self.args = parser.parse_args(sys.argv[1:])
+
+    def _train(self):
+        self.model.train(self.args.alpha, self.args.epoch)
+        self.model.write_theta()
+
+    def _predict(self):
+        if self.args.x is not None:
+            print(self.model.hypothesis(self.args.x))
+        else:
+            predict.predict_input(self.model)
+
+    def _cost(self):
+        print("Cost:", self.model.cost())
+
+    def _plot(self):
+        if not self.args.plot_data and not self.args.plot_model:
+            self.model.plot()
+        else:
+            self.model.plot(self.args.plot_data, self.args.plot_model)
+
+    def exec_args(self):
+        self.args.func()
+
+
+if __name__ == "__main__":
+    cli = CommandLineInterface()
+    cli.parse_args()
+    cli.exec_args()
diff --git a/src/cost.py b/src/cost.py
new file mode 100644
index 0000000..0f00253
--- /dev/null
+++ b/src/cost.py
@@ -0,0 +1,7 @@
+from model import Model
+
+
+if __name__ == "__main__":
+    m = Model(thetafilename="./theta")
+    print("Cost:", m.cost())
+
diff --git a/src/model.py b/src/model.py
new file mode 100644
index 0000000..8fb34a9
--- /dev/null
+++ b/src/model.py
@@ -0,0 +1,80 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.preprocessing import normalize as sklearn_normalize
+
+
+class Model:
+    def __init__(self, datafilename="../data.csv", thetafilename="./theta"):
+        self.datafilename = datafilename
+        self.thetafilename = thetafilename
+        self.theta1, self.theta0 = self._read_theta()
+        self.xs, self.ys = self._read_data()
+
+    def train(self, alpha=1, epoch=100):
+        self.xs, self.ys = self._normalize_data()
+        for _ in range(epoch):
+            next_theta0 = self.theta0 - alpha * self._partial_theta0()
+            next_theta1 = self.theta1 - alpha * self._partial_theta1()
+            self.theta0 = next_theta0
+            self.theta1 = next_theta1
+
+    def write_theta(self):
+        with open(self.datafilename, "w") as file:
+            file.write("{},{}".format(str(theta1), str(theta0)))
+
+    def hypothesis(self, x):
+        return x * self.theta1 + self.theta0
+
+    def cost(self):
+        return (1 / (2 * len(self.xs))) * sum([(self.hypothesis(x) - y) ** 2
+                                               for x, y in zip(self.xs, self.ys)])
+
+    def plot(self, plot_data=True, plot_model=True):
+        self.fig, self.ax = plt.subplots()
+        if plot_data:
+            self._plot_data()
+        if plot_model:
+            self._plot_model()
+        plt.show()
+
+    def _plot_data(self):
+        self.ax.scatter(self.xs, self.ys)
+
+    def _plot_model(self):
+        line_xs = [self.xs.min(), self.xs.max()]
+        line_ys = [self.hypothesis(x) for x in line_xs]
+        self.ax.plot(line_xs, line_ys)
+
+    def _partial_theta1(self):
+        return sum([(self.hypothesis(x) - y) * x
+                    for x, y in zip(self.xs, self.ys)]) / len(self.xs)
+
+    def _partial_theta0(self):
+        return sum([self.hypothesis(x) - y
+                    for x, y in zip(self.xs, self.ys)]) / len(self.xs)
+
+    def _normalize_data(self):
+        self.xs, self.ys = sklearn_normalize([self.xs, self.ys])
+
+    def _read_theta(self):
+        try:
+            with open(self.thetafilename, "r") as file:
+                strs = file.read().strip().split(",")
+                if len(strs) != 2:
+                    raise "wrong theta file format"
+                return int(strs[0]), int(strs[1])
+        except IOError:
+            print(self.thetafilename, "do not exist")
+
+    def _read_data(self):
+        try:
+            data = np.genfromtxt(self.datafilename, delimiter=",")
+            return data[:, 0], data[:, 1]
+        except IOError:
+            print(self.datafilename, "do not exist")
+
+
+if __name__ == "__main__":
+    m = Model()
+    m.train()
+    m.write_theta()
diff --git a/src/predict.py b/src/predict.py
index 85c7eac..329382a 100644
--- a/src/predict.py
+++ b/src/predict.py
@@ -1,22 +1,15 @@
-class Predictor:
-    def __init__(self, filename='theta'):
-        self.filename = filename
-        self.theta1, self.theta0 = self.read_theta()
+from model import Model
 
-    def make_prediction(self, x):
-        return x * self.theta1 + self.theta0
-
-    def read_theta(self):
+def predict_input(m):
+    while True:
         try:
-            with open(self.filename, 'r') as file:
-                strs = file.read().strip().split(",")
-                if len(strs) != 2:
-                    raise "wrong theta file format"
-                return int(strs[0]), int(strs[1])
-        except IOError:
-            print(self.filename, "do not exist")
+            x = int(input("Enter a mileage: "))
+        except ValueError:
+            print("Bad input, you should enter a number")
+        else:
+            break
+    print("The predicted price for this mileage is", m.hypothesis(x))
 
 if __name__ == "__main__":
-    p = Predictor()
-    x = int(input("Enter a mileage: "))
-    print("The predicted price for this mileage is", p.make_prediction(x))
+    m = Model(thetafilename="./theta")
+    predict_input(m)
diff --git a/src/train.py b/src/train.py
index a31d032..0a68916 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,32 +1,6 @@
-class Model:
-    def __init__(self, filename='../data.csv'):
-        self.datafile = filename
-
-    def train(self):
-        pass
-
-    def partial_theta1(self):
-        pass
-
-    def partial_theta0(self):
-        pass
-
-    def gradient_descent(self):
-        pass
-
-    def read_data(self):
-        pass
-
-    def normalize_data(self):
-        pass
-
-    def write_theta(self):
-        pass
+from model import Model
 
 
 if __name__ == "__main__":
-    m = Model()
-    m.read_data()
-    m.normalize_data()
+    m = Model(datafilename="../data.csv", thetafilename="./theta")
     m.train()
-    m.write_theta()
```
