aboutsummaryrefslogtreecommitdiff
path: root/src/cli.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/cli.py')
-rw-r--r--src/cli.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/src/cli.py b/src/cli.py
new file mode 100644
index 0000000..ce759be
--- /dev/null
+++ b/src/cli.py
@@ -0,0 +1,63 @@
+import sys
+import argparse
+
+from model import Model
+import predict
+
+class CommandLineInterface:
+ def __init__(self):
+ self.model = Model()
+
+ def parse_args(self):
+ parser = argparse.ArgumentParser(
+ prog="ft_linear_regression_cli",
+ description="CLI to interact with the ft_linear_regression project"
+ )
+ subparsers = parser.add_subparsers(help="sub-command help", dest="subparser_name")
+
+ parser_train = subparsers.add_parser("train", help="train the model")
+ parser_train.set_defaults(func=self._train)
+ parser_train.add_argument("-a --alpha", type=float, default=1.0, dest="alpha", help="learning rate")
+ parser_train.add_argument("-e --epoch", type=int, default=100, dest="epoch", help="number of iterations")
+
+ parser_predict = subparsers.add_parser("predict", help="make a predict")
+ parser_predict.set_defaults(func=self._predict)
+ parser_predict.add_argument("-x", type=int, help="mileage for which the prediction will be made")
+
+ parser_cost = subparsers.add_parser("cost", help="print model cost")
+ parser_cost.set_defaults(func=self._cost)
+
+ parser_plot = subparsers.add_parser("plot", help="plot data and model")
+ parser_plot.set_defaults(func=self._plot)
+ parser_plot.add_argument("-d --data", help="only plot data", action="store_true", dest="plot_data")
+ parser_plot.add_argument("-m --model", help="only plot model", action="store_true", dest="plot_model")
+
+ self.args = parser.parse_args(sys.argv[1:])
+
+ def _train(self):
+ self.model.train(self.args.alpha, self.args.epoch)
+ self.model.write_theta()
+
+ def _predict(self):
+ if self.args.x is not None:
+ print(self.model.hypothesis(self.args.x))
+ else:
+ predict.predict_input(self.model)
+
+ def _cost(self):
+ print("Cost:", self.model.cost())
+
+ def _plot(self):
+ if not self.args.plot_data and not self.args.plot_model:
+ self.model.plot()
+ else:
+ self.model.plot(self.args.plot_data, self.args.plot_model)
+
+ def exec_args(self):
+ self.args.func()
+
+
+if __name__ == "__main__":
+ cli = CommandLineInterface()
+ cli.parse_args()
+ cli.exec_args()