From 4537bdbebb5fe64c50080e7874d407f10a0676b7 Mon Sep 17 00:00:00 2001 From: Charles Date: Thu, 23 Jan 2020 09:57:39 +0100 Subject: WIP: CLI interface interact with Model class, subprogram call Model methods to satisfy the subject --- src/cli.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/cli.py (limited to 'src/cli.py') diff --git a/src/cli.py b/src/cli.py new file mode 100644 index 0000000..ce759be --- /dev/null +++ b/src/cli.py @@ -0,0 +1,63 @@ +import sys +import argparse + +from model import Model +import predict + +class CommandLineInterface: + def __init__(self): + self.model = Model() + + def parse_args(self): + parser = argparse.ArgumentParser( + prog="ft_linear_regression_cli", + description="CLI to interact with the ft_linear_regression project" + ) + subparsers = parser.add_subparsers(help="sub-command help", dest="subparser_name") + + parser_train = subparsers.add_parser("train", help="train the model") + parser_train.set_defaults(func=self._train) + parser_train.add_argument("-a --alpha", type=float, default=1.0, dest="alpha", help="learning rate") + parser_train.add_argument("-e --epoch", type=int, default=100, dest="epoch", help="number of iterations") + + parser_predict = subparsers.add_parser("predict", help="make a predict") + parser_predict.set_defaults(func=self._predict) + parser_predict.add_argument("-x", type=int, help="mileage for which the prediction will be made") + + parser_cost = subparsers.add_parser("cost", help="print model cost") + parser_cost.set_defaults(func=self._cost) + + parser_plot = subparsers.add_parser("plot", help="plot data and model") + parser_plot.set_defaults(func=self._plot) + parser_plot.add_argument("-d --data", help="only plot data", action="store_true", dest="plot_data") + parser_plot.add_argument("-m --model", help="only plot model", action="store_true", dest="plot_model") + + self.args = parser.parse_args(sys.argv[1:]) + + def _train(self): + self.model.train(self.args.alpha, self.args.epoch) + self.model.write_theta() + + def _predict(self): + if self.args.x is not None: + print(self.model.hypothesis(self.args.x)) + else: + predict.predict_input(self.model) + + def _cost(self): + print("Cost:", self.model.cost()) + + def _plot(self): + if not self.args.plot_data and not self.args.plot_model: + self.model.plot() + else: + self.model.plot(self.args.plot_data, self.args.plot_model) + + def exec_args(self): + self.args.func() + + +if __name__ == "__main__": + cli = CommandLineInterface() + cli.parse_args() + cli.exec_args() -- cgit