diff options
| author | Charles <sircharlesaze@gmail.com> | 2020-01-23 09:57:39 +0100 |
|---|---|---|
| committer | Charles <sircharlesaze@gmail.com> | 2020-01-23 09:57:39 +0100 |
| commit | 4537bdbebb5fe64c50080e7874d407f10a0676b7 (patch) | |
| tree | 8466528ae9c3865e2543e7fdb9e9301bfef40490 /src/cli.py | |
| parent | 1e45848f0b84218dfbdb62e313b4c33791a98555 (diff) | |
| download | ft_linear_regression-4537bdbebb5fe64c50080e7874d407f10a0676b7.tar.gz ft_linear_regression-4537bdbebb5fe64c50080e7874d407f10a0676b7.tar.bz2 ft_linear_regression-4537bdbebb5fe64c50080e7874d407f10a0676b7.zip | |
WIP: CLI interface interact with Model class, subprogram call Model methods to satisfy the subject
Diffstat (limited to 'src/cli.py')
| -rw-r--r-- | src/cli.py | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/src/cli.py b/src/cli.py new file mode 100644 index 0000000..ce759be --- /dev/null +++ b/src/cli.py @@ -0,0 +1,63 @@ +import sys +import argparse + +from model import Model +import predict + +class CommandLineInterface: + def __init__(self): + self.model = Model() + + def parse_args(self): + parser = argparse.ArgumentParser( + prog="ft_linear_regression_cli", + description="CLI to interact with the ft_linear_regression project" + ) + subparsers = parser.add_subparsers(help="sub-command help", dest="subparser_name") + + parser_train = subparsers.add_parser("train", help="train the model") + parser_train.set_defaults(func=self._train) + parser_train.add_argument("-a --alpha", type=float, default=1.0, dest="alpha", help="learning rate") + parser_train.add_argument("-e --epoch", type=int, default=100, dest="epoch", help="number of iterations") + + parser_predict = subparsers.add_parser("predict", help="make a predict") + parser_predict.set_defaults(func=self._predict) + parser_predict.add_argument("-x", type=int, help="mileage for which the prediction will be made") + + parser_cost = subparsers.add_parser("cost", help="print model cost") + parser_cost.set_defaults(func=self._cost) + + parser_plot = subparsers.add_parser("plot", help="plot data and model") + parser_plot.set_defaults(func=self._plot) + parser_plot.add_argument("-d --data", help="only plot data", action="store_true", dest="plot_data") + parser_plot.add_argument("-m --model", help="only plot model", action="store_true", dest="plot_model") + + self.args = parser.parse_args(sys.argv[1:]) + + def _train(self): + self.model.train(self.args.alpha, self.args.epoch) + self.model.write_theta() + + def _predict(self): + if self.args.x is not None: + print(self.model.hypothesis(self.args.x)) + else: + predict.predict_input(self.model) + + def _cost(self): + print("Cost:", self.model.cost()) + + def _plot(self): + if not self.args.plot_data and not self.args.plot_model: + self.model.plot() + else: + self.model.plot(self.args.plot_data, self.args.plot_model) + + def exec_args(self): + self.args.func() + + +if __name__ == "__main__": + cli = CommandLineInterface() + cli.parse_args() + cli.exec_args() |
