aboutsummaryrefslogtreecommitdiff
path: root/src/analysis.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/analysis.py')
-rw-r--r--src/analysis.py29
1 files changed, 26 insertions, 3 deletions
diff --git a/src/analysis.py b/src/analysis.py
index abc0ffb..b6c9eb9 100644
--- a/src/analysis.py
+++ b/src/analysis.py
@@ -1,12 +1,14 @@
import numpy as np
import pandas as pd
+import matplotlib.pyplot as plt
+from dataset import Dataset
import dslr_stat
-class Analysis:
- def __init__(self, df):
- self.df = df
+class Analysis(Dataset):
+ def __init__(self, path):
+ super().__init__(path)
def describe(self):
desc_df = pd.DataFrame(
@@ -24,3 +26,24 @@ class Analysis:
desc_df.loc['75%', col] = dslr_stat.q75(self.df[col])
desc_df.loc['Max', col] = dslr_stat.max(self.df[col])
print(desc_df)
+
+ def hist(self):
+ pass
+
+ def scatter(self):
+ plt.scatter(self.df['astronomy'], self.df['defense_against_the_dark_arts'])
+ plt.show()
+
+ def pair_plot(self):
+ scores = self.df_scores
+ fig, axis = plt.subplots(nrows=scores.shape[1],
+ ncols=scores.shape[1])
+ for i, col in enumerate(scores.columns):
+ for j, pair_col in enumerate(scores.columns):
+ ax = axis[i, j]
+ if pair_col == col:
+ ax.hist(scores)
+ continue
+ ax.scatter(scores[col], scores[pair_col])
+ plt.tight_layout()
+ plt.show()