From 9f04fafe42fd6436bec09696e1bc8b2abc496cc4 Mon Sep 17 00:00:00 2001 From: Charles Date: Sat, 25 Jan 2020 14:16:00 +0100 Subject: Dataset parent of Analysis, scatter plot and pair_plot dirty scripts --- src/analysis.py | 29 ++++++++++++++++++++++++++--- src/dataset.py | 21 +++++++++++++++++++++ src/describe.py | 15 +++------------ src/pair_plot.py | 6 ++++++ src/scatter_plot.py | 6 ++++++ 5 files changed, 62 insertions(+), 15 deletions(-) create mode 100644 src/dataset.py create mode 100644 src/pair_plot.py create mode 100644 src/scatter_plot.py (limited to 'src') diff --git a/src/analysis.py b/src/analysis.py index abc0ffb..b6c9eb9 100644 --- a/src/analysis.py +++ b/src/analysis.py @@ -1,12 +1,14 @@ import numpy as np import pandas as pd +import matplotlib.pyplot as plt +from dataset import Dataset import dslr_stat -class Analysis: - def __init__(self, df): - self.df = df +class Analysis(Dataset): + def __init__(self, path): + super().__init__(path) def describe(self): desc_df = pd.DataFrame( @@ -24,3 +26,24 @@ class Analysis: desc_df.loc['75%', col] = dslr_stat.q75(self.df[col]) desc_df.loc['Max', col] = dslr_stat.max(self.df[col]) print(desc_df) + + def hist(self): + pass + + def scatter(self): + plt.scatter(self.df['astronomy'], self.df['defense_against_the_dark_arts']) + plt.show() + + def pair_plot(self): + scores = self.df_scores + fig, axis = plt.subplots(nrows=scores.shape[1], + ncols=scores.shape[1]) + for i, col in enumerate(scores.columns): + for j, pair_col in enumerate(scores.columns): + ax = axis[i, j] + if pair_col == col: + ax.hist(scores) + continue + ax.scatter(scores[col], scores[pair_col]) + plt.tight_layout() + plt.show() diff --git a/src/dataset.py b/src/dataset.py new file mode 100644 index 0000000..650d334 --- /dev/null +++ b/src/dataset.py @@ -0,0 +1,21 @@ +import pandas as pd + + +class Dataset: + def __init__(self, path): + self.path = path + try: + self.df = pd.read_csv(path) + except FileNotFoundError: + raise "Couldn't find dataset at: {}".format(path) + self.df.drop(columns=['Index'], inplace=True) + self.df.dropna(inplace=True) + self.df.columns = self.df.columns.str.lower() + self.df.columns = self.df.columns.str.replace(' ', '_') + self.df.rename(columns={'hogwarts_house': 'house'}, inplace=True) + + @property + def df_scores(self): + return self.df.loc[:, 'arithmancy':'flying'] + + diff --git a/src/describe.py b/src/describe.py index 7a968f1..4a3c5bc 100644 --- a/src/describe.py +++ b/src/describe.py @@ -1,20 +1,11 @@ import sys -import pandas as pd from analysis import Analysis if __name__ == "__main__": if len(sys.argv) != 2: - print("Usage: {} dataset_path".format(sys.argv[0])) - sys.exit(1) - try: - df = pd.read_csv(sys.argv[1]) - except FileNotFoundError: - print("Could not find dataset at: {}".format(sys.argv[1])) - sys.exit(1) - df = df.loc[:, 'Arithmancy':'Flying'] - df.dropna(inplace=True) - a = Analysis(df) + raise "Usage: {} dataset_path".format(sys.argv[0]) + a = Analysis(sys.argv[1]) a.describe() - print(df.describe()) + print(a.df_scores.describe()) diff --git a/src/pair_plot.py b/src/pair_plot.py new file mode 100644 index 0000000..bf0c632 --- /dev/null +++ b/src/pair_plot.py @@ -0,0 +1,6 @@ +from analysis import Analysis + + +if __name__ == '__main__': + a = Analysis('../datasets/dataset_train.csv') + a.pair_plot() diff --git a/src/scatter_plot.py b/src/scatter_plot.py new file mode 100644 index 0000000..74e0384 --- /dev/null +++ b/src/scatter_plot.py @@ -0,0 +1,6 @@ +from analysis import Analysis + + +if __name__ == '__main__': + a = Analysis('../datasets/dataset_train.csv') + a.scatter() -- cgit