From c2d4aaf8be815b5391b9d0ddb45cb850ebc5d407 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sat, 3 Dec 2022 22:23:29 +1100 Subject: [PATCH] Implement yli.auto_correlations --- yli/__init__.py | 4 +- yli/descriptives.py | 108 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 109 insertions(+), 3 deletions(-) diff --git a/yli/__init__.py b/yli/__init__.py index 61cf681..f5754b4 100644 --- a/yli/__init__.py +++ b/yli/__init__.py @@ -16,7 +16,7 @@ from .bayes_factors import bayesfactor_afbf from .config import config -from .descriptives import auto_descriptives +from .descriptives import auto_correlations, auto_descriptives from .distributions import beta_oddsratio, beta_ratio, hdi, transformed_dist from .io import pickle_read_compressed, pickle_read_encrypted, pickle_write_compressed, pickle_write_encrypted from .regress import OrdinalLogit, PenalisedLogit, logit_then_regress, regress, vif @@ -33,7 +33,7 @@ def reload_me(): try: importlib.reload(v) except ModuleNotFoundError as ex: - if ex.name.startswith('yli.'): + if ex.name == k: # Must be due to a module which we deleted - can safely ignore pass else: diff --git a/yli/descriptives.py b/yli/descriptives.py index a27b7ec..dadbe79 100644 --- a/yli/descriptives.py +++ b/yli/descriptives.py @@ -15,9 +15,11 @@ # along with this program. If not, see . import pandas as pd +from scipy import stats +import seaborn as sns from .config import config -from .utils import check_nan +from .utils import as_numeric, check_nan def auto_descriptives(df, cols, *, ordinal_range=[]): """ @@ -140,3 +142,107 @@ class AutoDescriptivesResult: result_labels_fmt = [r[0] for r in self._result_labels] table = pd.DataFrame(self._result_data, index=result_labels_fmt, columns=['', 'Missing']) return str(table) + +def auto_correlations(df, cols): + # TODO: Documentation + + def _col_to_numeric(col): + if col.dtype == 'category' and col.cat.ordered: + # Ordinal variable + # Factorise if required + col, _ = as_numeric(col) + + # Code as ranks + col[col >= 0] = stats.rankdata(col[col >= 0]) + + # Put NaNs back + col = col.astype('float64') + col[col < 0] = pd.NA + + return col + else: + # FIXME: Bools, binary, etc. + return col + + # Code columns as numeric/ranks/etc. as appropriate + df_coded = pd.DataFrame(index=df.index) + + for col_name in cols: + col = df[col_name] + + if col.dtype == 'category' and col.cat.ordered: + # Ordinal variable + # Factorise if required + col, _ = as_numeric(col) + + # Code as ranks + col[col >= 0] = stats.rankdata(col[col >= 0]) + + # Put NaNs back + col = col.astype('float64') + col[col < 0] = pd.NA + + df_coded[col_name] = col + elif col.dtype in ('bool', 'boolean', 'category', 'object'): + cat_values = col.dropna().unique() + + if len(cat_values) == 2: + # Categorical variable with 2 categories + # Code as 0/1/NA + cat_values = sorted(cat_values) + col = col.replace({cat_values[0]: 0, cat_values[1]: 1}) + df_coded[col_name] = col + else: + # Categorical variable with >2 categories + # Create dummy variables + dummies = pd.get_dummies(col, prefix=col_name) + df_coded = df_coded.join(dummies) + else: + # Numeric variable, etc. + df_coded[col_name] = col + + # Compute pairwise correlation + df_corr = pd.DataFrame(index=df_coded.columns, columns=df_coded.columns, dtype='float64') + + for i, col1 in enumerate(df_coded.columns): + for col2 in df_coded.columns[:i]: + statistic = stats.pearsonr(df_coded[col1], df_coded[col2]).statistic + df_corr.loc[col1, col2] = statistic + df_corr.loc[col2, col1] = statistic + + # Correlation with itself is always 1 + df_corr.loc[col1, col1] = 1 + + return AutoCorrelationsResult(df_corr) + +class AutoCorrelationsResult: + # TODO: Documentation + + def __init__(self, correlations): + self.correlations = correlations + + def __repr__(self): + if config.repr_is_summary: + return self.summary() + return super().__repr__() + + def _repr_html_(self): + df_repr = self.correlations._repr_html_() + + # Insert caption + idx_endopen = df_repr.index('>', df_repr.index('Correlation Matrix' + df_repr[idx_endopen+1:] + + return df_repr + + def summary(self): + """ + Return a stringified summary of the correlation matrix + + :rtype: str + """ + + return 'Correlation Matrix\n\n' + str(self.correlations) + + def plot(self): + sns.heatmap(self.correlations, vmin=-1, vmax=1, cmap='RdBu')