From 967b853b02b3b26f524864520ef1dd68b3d53d4d Mon Sep 17 00:00:00 2001
From: RunasSudo
Date: Tue, 7 Feb 2023 18:50:07 +1100
Subject: [PATCH] Implement utilities for SHAP values in regression

---
 yli/regress.py | 31 +++++++++++++++++
 yli/shap.py    | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 124 insertions(+)
 create mode 100644 yli/shap.py

diff --git a/yli/regress.py b/yli/regress.py
index ec79cc6..e770380 100644
--- a/yli/regress.py
+++ b/yli/regress.py
@@ -32,6 +32,7 @@ import weakref
 
 from .bayes_factors import BayesFactor, bayesfactor_afbf
 from .config import config
+from .shap import ShapResult
 from .sig_tests import ChiSquaredResult, FTestResult
 from .utils import Estimate, PValueStyle, as_numeric, check_nan, cols_for_formula, convert_pandas_nullable, fmt_p, formula_factor_ref_category, parse_patsy_term
 
@@ -460,6 +461,36 @@ class RegressionResult:
 			self.exp
 		)
 	
+	def shap(self, **kwargs):
+		"""
+		Compute SHAP values for the model using *shap.LinearExplainer*
+
+		The fitted coefficients and intercept are passed to the explainer along with the data for the independent variables.
+		The zeroth term of the model is assumed to be the intercept.
+
+		:param kwargs: Additional keyword arguments to pass to *shap.LinearExplainer*
+
+		:return: SHAP values for each observation and feature
+		:rtype: :class:`yli.shap.ShapResult`
+		"""
+
+		import shap
+
+		xdata = ShapResult._get_xdata(self)
+
+		# Combine terms into single list
+		params = []
+		for term in self.terms.values():
+			if isinstance(term, SingleTerm):
+				params.append(term.beta.point)
+			else:
+				params.extend(s.beta.point for s in term.categories.values())
+
+		explainer = shap.LinearExplainer((np.array(params[1:]), params[0]), xdata, **kwargs) # FIXME: Assumes zeroth term is intercept
+		shap_values = explainer.shap_values(xdata).astype('float')
+
+		return ShapResult(weakref.ref(self), shap_values, list(xdata.columns))
+
 	def _header_table(self, html):
 		"""Return the entries for the header table"""
 		
diff --git a/yli/shap.py b/yli/shap.py
new file mode 100644
index 0000000..96692ee
--- /dev/null
+++ b/yli/shap.py
@@ -0,0 +1,93 @@
+import pandas as pd
+import patsy
+
+from .utils import as_numeric, check_nan, cols_for_formula, convert_pandas_nullable
+
+class ShapResult:
+	"""
+	SHAP values for a regression model
+
+	See :meth:`yli.regress.RegressionResult.shap`.
+	"""
+
+	def __init__(self, model, shap_values, features):
+		#: Weak reference to the regression model the values were computed for (called to dereference)
+		self.model = model
+		#: Array of SHAP values, with one column per feature
+		self.shap_values = shap_values
+		#: Names of the features, in the same order as the columns of *shap_values*
+		self.features = features
+
+	@staticmethod
+	def _get_xdata(model):
+		"""
+		Rebuild the matrix of independent variables for the model, excluding the intercept
+
+		:param model: Regression model, providing *df* (callable returning the data, or *None* if dropped), *dep* and *formula*
+
+		:return: Design matrix for *model.formula* with the zeroth (assumed intercept) column removed
+		:rtype: DataFrame
+		"""
+
+		df = model.df()
+		if df is None:
+			raise Exception('Referenced DataFrame has been dropped')
+		dep = model.dep
+
+		# Check for/clean NaNs
+		# NaN warning/error will already have been handled in regress, so here we pass nan_policy='omit'
+		df = df[[dep] + cols_for_formula(model.formula, df)]
+		df = check_nan(df, 'omit')
+
+		# Ensure numeric type for dependent variable (the second return value of as_numeric is not needed here)
+		df[dep], _ = as_numeric(df[dep])
+
+		# Convert pandas nullable types for independent variables as this breaks statsmodels
+		df = convert_pandas_nullable(df)
+
+		# Get xdata for SHAP
+		dmatrix = patsy.dmatrix(model.formula, df, return_type='dataframe')
+		xdata = dmatrix.iloc[:, 1:] # FIXME: Assumes zeroth term is intercept
+
+		return xdata
+
+	def mean(self):
+		"""
+		Compute the mean absolute SHAP value for each feature
+
+		:rtype: Series
+		"""
+
+		return pd.Series(abs(self.shap_values).mean(axis=0), index=self.features)
+
+	def plot(self, **kwargs):
+		"""
+		Generate a SHAP summary plot using *shap.summary_plot*
+
+		:param kwargs: Additional keyword arguments to pass to *shap.summary_plot*
+
+		:return: (*fig*, *ax*), the current matplotlib figure and axes after plotting
+		:raises Exception: If the referenced regression model, or its data, has been dropped
+		"""
+
+		import matplotlib.pyplot as plt
+		import shap
+
+		model = self.model()
+		if model is None:
+			raise Exception('Referenced RegressionResult has been dropped')
+
+		xdata = self._get_xdata(model)
+
+		shap.summary_plot(self.shap_values, xdata, show=False, axis_color='black', **kwargs) # pass show=False to get gcf/gca
+
+		# Fix colour bar
+		# https://stackoverflow.com/questions/70461753/shap-the-color-bar-is-not-displayed-in-the-summary-plot
+		# Only adjust when an extra axes exists: without a colour bar (e.g. plot_type='bar') the last axes is the plot itself
+		fig = plt.gcf()
+		if len(fig.axes) > 1:
+			ax_colorbar = fig.axes[-1]
+			ax_colorbar.set_aspect('auto')
+			ax_colorbar.set_box_aspect(50)
+
+		return fig, plt.gca()