scipy-yli/yli/shap.py

88 lines
2.6 KiB
Python

# scipy-yli: Helpful SciPy utilities and recipes
# Copyright © 2022–2025 Lee Yingtong Li (RunasSudo)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import pandas as pd
import patsy
from .utils import as_numeric, check_nan, cols_for_formula, convert_pandas_nullable
class ShapResult:
    """
    SHAP values for a regression model

    See :meth:`yli.regress.RegressionModel.shap`.
    """

    def __init__(self, model, explanation):
        """
        :param model: Callable returning the regression model these SHAP values belong to, or *None* if that model
            is no longer available (presumably a weak reference – see the check in :meth:`plot`)
        :param explanation: SHAP explanation object, as accepted by :func:`shap.plots.beeswarm`
        """
        self.model = model
        self.explanation = explanation

    @staticmethod
    def _get_xdata(model):
        """
        Rebuild the design matrix for *model*, excluding any intercept column, for use as SHAP input data

        :param model: Regression model exposing ``df()`` (returning the source DataFrame or *None*), ``dep`` and
            ``formula``
        :rtype: DataFrame
        :raises Exception: If the model's referenced DataFrame has been dropped
        """
        df = model.df()
        if df is None:
            raise Exception('Referenced DataFrame has been dropped')
        dep = model.dep

        # Check for/clean NaNs
        # NaN warning/error will already have been handled in regress, so here we pass nan_policy='omit'
        df = df[[dep] + cols_for_formula(model.formula, df)]
        # Take an explicit copy: check_nan may hand back a slice of the selection above, and we assign into it below
        df = check_nan(df, 'omit').copy()

        # Ensure numeric type for dependent variable
        # TODO: Is this step necessary?
        df[dep], _ = as_numeric(df[dep])

        # Convert pandas nullable types for independent variables as this breaks statsmodels
        df = convert_pandas_nullable(df)

        # Get xdata for SHAP
        dmatrix = patsy.dmatrix(model.formula, df, return_type='dataframe')

        # Drop the intercept column(s) only if the formula actually has an intercept term
        # (patsy places the intercept first, so this matches dropping column 0 whenever an intercept exists,
        # but no longer discards a genuine predictor for formulae without an intercept, e.g. "0 + x")
        design_info = dmatrix.design_info
        if patsy.INTERCEPT in design_info.terms:
            intercept_columns = dmatrix.columns[design_info.slice(patsy.INTERCEPT)]
            return dmatrix.drop(columns=intercept_columns)
        return dmatrix

    def mean(self):
        """
        Compute the mean absolute SHAP value for each parameter

        :rtype: Series
        """
        # Assumes explanation.values is 2-D (observations × features), labelled by explanation.feature_names;
        # if feature_names is None, the result falls back to a default integer index
        abs_values = pd.DataFrame(self.explanation.values, columns=self.explanation.feature_names).abs()
        return abs_values.mean(axis=0)

    def plot(self, **kwargs):
        """
        Generate a scatterplot (beeswarm plot) of the SHAP values

        Uses the Python *matplotlib* library.

        :param kwargs: Additional keyword arguments passed through to :func:`shap.plots.beeswarm`
        :rtype: (Figure, Axes)
        :raises Exception: If the referenced regression model has been dropped
        """
        import matplotlib.pyplot as plt
        import shap

        # Refuse to plot results whose originating model no longer exists
        if self.model() is None:
            raise Exception('Referenced RegressionModel has been dropped')

        # Pass show=False so the figure stays open and can be retrieved with gcf/gca
        shap.plots.beeswarm(self.explanation, show=False, axis_color='black', max_display=None, **kwargs)
        return plt.gcf(), plt.gca()