import pandas as pd
import patsy

from .utils import as_numeric, check_nan, cols_for_formula, convert_pandas_nullable


class ShapResult:
    """
    SHAP values for a regression model

    See :meth:`yli.regress.RegressionModel.shap`.
    """

    def __init__(self, model, shap_values, features):
        # Zero-argument callable returning the regression model, or None once
        # the model has been dropped (presumably a weakref - confirm with caller)
        self.model = model
        # Array-like of SHAP values; mean() reduces it along axis 0 and labels
        # the result with `features`
        self.shap_values = shap_values
        # Labels for the parameters, used as the index of the Series from mean()
        self.features = features

    @staticmethod
    def _get_xdata(model):
        """
        Build the design matrix of independent variables for *model*

        :param model: Regression model providing ``df()``, ``dep`` and ``formula``
        :return: patsy design matrix for ``model.formula`` with the first column removed
        :rtype: DataFrame
        :raises Exception: If the DataFrame referenced by the model has been dropped
        """

        df = model.df()
        if df is None:
            raise Exception('Referenced DataFrame has been dropped')

        dep = model.dep

        # Check for/clean NaNs
        # NaN warning/error will already have been handled in regress, so here we pass nan_policy='omit'
        df = df[[dep] + cols_for_formula(model.formula, df)]
        df = check_nan(df, 'omit')

        # Ensure numeric type for dependent variable
        # (the second value returned by as_numeric is not needed here)
        df[dep], _ = as_numeric(df[dep])

        # Convert pandas nullable types for independent variables as this breaks statsmodels
        df = convert_pandas_nullable(df)

        # Get xdata for SHAP
        dmatrix = patsy.dmatrix(model.formula, df, return_type='dataframe')
        xdata = dmatrix.iloc[:, 1:]  # FIXME: Assumes zeroth term is intercept

        return xdata

    def mean(self):
        """
        Compute the mean absolute SHAP value for each parameter

        :rtype: Series
        """

        return pd.Series(abs(self.shap_values).mean(axis=0), index=self.features)

    def plot(self, **kwargs):
        """
        Generate a scatterplot of the SHAP values

        Uses the Python *matplotlib* library.

        :param kwargs: Additional keyword arguments passed through to *shap.summary_plot*
        :rtype: (Figure, Axes)
        :raises Exception: If the referenced regression model (or its DataFrame) has been dropped
        """

        import matplotlib.pyplot as plt
        import shap

        model = self.model()
        if model is None:
            raise Exception('Referenced RegressionModel has been dropped')

        xdata = self._get_xdata(model)

        shap.summary_plot(self.shap_values, xdata, show=False, axis_color='black', **kwargs)  # pass show=False to get gcf/gca

        fig = plt.gcf()
        ax = plt.gca()

        # Fix colour bar
        # https://stackoverflow.com/questions/70461753/shap-the-color-bar-is-not-displayed-in-the-summary-plot
        # Only apply the fix when the last axes is a separate axes from the
        # current one; if kwargs resulted in no colour bar being drawn, the
        # last axes is the plot itself and must not be reshaped.
        ax_colorbar = fig.axes[-1]
        if ax_colorbar is not ax:
            ax_colorbar.set_aspect('auto')
            ax_colorbar.set_box_aspect(50)

        return fig, ax