scipy-yli/yli/shap.py

79 lines
2.1 KiB
Python
Raw Normal View History

import pandas as pd
import patsy
from .utils import as_numeric, check_nan, cols_for_formula, convert_pandas_nullable
class ShapResult:
    """
    SHAP values for a regression model

    See :meth:`yli.regress.RegressionModel.shap`.
    """

    def __init__(self, model, shap_values, features):
        # Zero-argument callable returning the RegressionModel, or None once it has been dropped
        # (plot() calls self.model() and checks for None) - presumably a weakref; confirm with caller
        self.model = model
        # Array-like of SHAP values; mean() averages absolute values over axis 0, one result per feature
        self.shap_values = shap_values
        # Labels for the columns of shap_values, used as the index of mean()
        self.features = features

    @staticmethod
    def _get_xdata(model):
        """
        Build the design matrix (minus its zeroth column) used for SHAP

        :param model: Regression model providing ``df()``, ``dep`` and ``formula``
        :rtype: DataFrame
        :raises Exception: If the DataFrame referenced by *model* has been dropped
        """
        df = model.df()
        if df is None:
            raise Exception('Referenced DataFrame has been dropped')
        dep = model.dep

        # Check for/clean NaNs
        # NaN warning/error will already have been handled in regress, so here we pass nan_policy='omit'
        df = df[[dep] + cols_for_formula(model.formula, df)]
        df = check_nan(df, 'omit')

        # Ensure numeric type for dependent variable (the category mapping is not needed here)
        df[dep], _ = as_numeric(df[dep])

        # Convert pandas nullable types for independent variables as this breaks statsmodels
        df = convert_pandas_nullable(df)

        # Get xdata for SHAP
        dmatrix = patsy.dmatrix(model.formula, df, return_type='dataframe')
        # FIXME: Assumes zeroth term is intercept; a formula without an intercept would lose a real regressor here
        xdata = dmatrix.iloc[:, 1:]
        return xdata

    def mean(self):
        """
        Compute the mean absolute SHAP value for each parameter

        :rtype: Series
        """
        return pd.Series(abs(self.shap_values).mean(axis=0), index=self.features)

    def plot(self, **kwargs):
        """
        Generate a scatterplot of the SHAP values

        Uses the Python *matplotlib* library.

        :param kwargs: Passed through to :func:`shap.summary_plot`; *axis_color* defaults to ``'black'``
        :rtype: (Figure, Axes)
        :raises Exception: If the referenced RegressionModel has been dropped
        """
        import matplotlib.pyplot as plt
        import shap

        model = self.model()
        if model is None:
            raise Exception('Referenced RegressionModel has been dropped')
        xdata = self._get_xdata(model)

        # Default axis_color rather than hard-coding it, so callers may override it
        # without a duplicate-keyword TypeError
        kwargs.setdefault('axis_color', 'black')
        shap.summary_plot(self.shap_values, xdata, show=False, **kwargs)  # pass show=False to get gcf/gca

        fig = plt.gcf()
        ax = plt.gca()

        # Fix colour bar
        # https://stackoverflow.com/questions/70461753/shap-the-color-bar-is-not-displayed-in-the-summary-plot
        # Only when a colour bar was actually drawn: without one (e.g. plot_type='bar' or color_bar=False)
        # the last axes is the main plot itself, which would be squashed into a narrow strip
        ax_colorbar = fig.axes[-1]
        if ax_colorbar is not ax:
            ax_colorbar.set_aspect('auto')
            ax_colorbar.set_box_aspect(50)

        return fig, ax