aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorCharles <sircharlesaze@gmail.com>2020-01-25 14:16:00 +0100
committerCharles <sircharlesaze@gmail.com>2020-01-25 14:16:00 +0100
commit9f04fafe42fd6436bec09696e1bc8b2abc496cc4 (patch)
treeefeba48e71f0053e63578d35204542f61118ff1b /src
parentdea0f4cdec5bdf24962c8ab3ab2a6473e202259a (diff)
downloaddslr-9f04fafe42fd6436bec09696e1bc8b2abc496cc4.tar.gz
dslr-9f04fafe42fd6436bec09696e1bc8b2abc496cc4.tar.bz2
dslr-9f04fafe42fd6436bec09696e1bc8b2abc496cc4.zip
Dataset parent of Analysis, scatter plot and pair_plot dirty scripts
Diffstat (limited to 'src')
-rw-r--r--src/analysis.py29
-rw-r--r--src/dataset.py21
-rw-r--r--src/describe.py15
-rw-r--r--src/pair_plot.py6
-rw-r--r--src/scatter_plot.py6
5 files changed, 62 insertions, 15 deletions
diff --git a/src/analysis.py b/src/analysis.py
index abc0ffb..b6c9eb9 100644
--- a/src/analysis.py
+++ b/src/analysis.py
@@ -1,12 +1,14 @@
import numpy as np
import pandas as pd
+import matplotlib.pyplot as plt
+from dataset import Dataset
import dslr_stat
-class Analysis:
- def __init__(self, df):
- self.df = df
+class Analysis(Dataset):
+ def __init__(self, path):
+ super().__init__(path)
def describe(self):
desc_df = pd.DataFrame(
@@ -24,3 +26,24 @@ class Analysis:
desc_df.loc['75%', col] = dslr_stat.q75(self.df[col])
desc_df.loc['Max', col] = dslr_stat.max(self.df[col])
print(desc_df)
+
+ def hist(self):
+ pass
+
+ def scatter(self):
+ plt.scatter(self.df['astronomy'], self.df['defense_against_the_dark_arts'])
+ plt.show()
+
+ def pair_plot(self):
+ scores = self.df_scores
+ fig, axis = plt.subplots(nrows=scores.shape[1],
+ ncols=scores.shape[1])
+ for i, col in enumerate(scores.columns):
+ for j, pair_col in enumerate(scores.columns):
+ ax = axis[i, j]
+ if pair_col == col:
+ ax.hist(scores)
+ continue
+ ax.scatter(scores[col], scores[pair_col])
+ plt.tight_layout()
+ plt.show()
diff --git a/src/dataset.py b/src/dataset.py
new file mode 100644
index 0000000..650d334
--- /dev/null
+++ b/src/dataset.py
@@ -0,0 +1,21 @@
+import pandas as pd
+
+
+class Dataset:
+ def __init__(self, path):
+ self.path = path
+ try:
+ self.df = pd.read_csv(path)
+ except FileNotFoundError:
+ raise "Couldn't find dataset at: {}".format(path)
+ self.df.drop(columns=['Index'], inplace=True)
+ self.df.dropna(inplace=True)
+ self.df.columns = self.df.columns.str.lower()
+ self.df.columns = self.df.columns.str.replace(' ', '_')
+ self.df.rename(columns={'hogwarts_house': 'house'}, inplace=True)
+
+ @property
+ def df_scores(self):
+ return self.df.loc[:, 'arithmancy':'flying']
+
+
diff --git a/src/describe.py b/src/describe.py
index 7a968f1..4a3c5bc 100644
--- a/src/describe.py
+++ b/src/describe.py
@@ -1,20 +1,11 @@
import sys
-import pandas as pd
from analysis import Analysis
if __name__ == "__main__":
if len(sys.argv) != 2:
- print("Usage: {} dataset_path".format(sys.argv[0]))
- sys.exit(1)
- try:
- df = pd.read_csv(sys.argv[1])
- except FileNotFoundError:
- print("Could not find dataset at: {}".format(sys.argv[1]))
- sys.exit(1)
- df = df.loc[:, 'Arithmancy':'Flying']
- df.dropna(inplace=True)
- a = Analysis(df)
+ raise "Usage: {} dataset_path".format(sys.argv[0])
+ a = Analysis(sys.argv[1])
a.describe()
- print(df.describe())
+ print(a.df_scores.describe())
diff --git a/src/pair_plot.py b/src/pair_plot.py
new file mode 100644
index 0000000..bf0c632
--- /dev/null
+++ b/src/pair_plot.py
@@ -0,0 +1,6 @@
+from analysis import Analysis
+
+
+if __name__ == '__main__':
+ a = Analysis('../datasets/dataset_train.csv')
+ a.pair_plot()
diff --git a/src/scatter_plot.py b/src/scatter_plot.py
new file mode 100644
index 0000000..74e0384
--- /dev/null
+++ b/src/scatter_plot.py
@@ -0,0 +1,6 @@
+from analysis import Analysis
+
+
+if __name__ == '__main__':
+ a = Analysis('../datasets/dataset_train.csv')
+ a.scatter()