Implement utilities for SHAP values in regression
This commit is contained in:
parent
dbebc3b8e9
commit
967b853b02
@ -32,6 +32,7 @@ import weakref
|
|||||||
|
|
||||||
from .bayes_factors import BayesFactor, bayesfactor_afbf
|
from .bayes_factors import BayesFactor, bayesfactor_afbf
|
||||||
from .config import config
|
from .config import config
|
||||||
|
from .shap import ShapResult
|
||||||
from .sig_tests import ChiSquaredResult, FTestResult
|
from .sig_tests import ChiSquaredResult, FTestResult
|
||||||
from .utils import Estimate, PValueStyle, as_numeric, check_nan, cols_for_formula, convert_pandas_nullable, fmt_p, formula_factor_ref_category, parse_patsy_term
|
from .utils import Estimate, PValueStyle, as_numeric, check_nan, cols_for_formula, convert_pandas_nullable, fmt_p, formula_factor_ref_category, parse_patsy_term
|
||||||
|
|
||||||
@ -460,6 +461,26 @@ class RegressionResult:
|
|||||||
self.exp
|
self.exp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def shap(self, **kwargs):
|
||||||
|
# TODO: Documentation
|
||||||
|
|
||||||
|
import shap
|
||||||
|
|
||||||
|
xdata = ShapResult._get_xdata(self)
|
||||||
|
|
||||||
|
# Combine terms into single list
|
||||||
|
params = []
|
||||||
|
for term in self.terms.values():
|
||||||
|
if isinstance(term, SingleTerm):
|
||||||
|
params.append(term.beta.point)
|
||||||
|
else:
|
||||||
|
params.extend(s.beta.point for s in term.categories.values())
|
||||||
|
|
||||||
|
explainer = shap.LinearExplainer((np.array(params[1:]), params[0]), xdata, **kwargs) # FIXME: Assumes zeroth term is intercept
|
||||||
|
shap_values = explainer.shap_values(xdata).astype('float')
|
||||||
|
|
||||||
|
return ShapResult(weakref.ref(self), shap_values, list(xdata.columns))
|
||||||
|
|
||||||
def _header_table(self, html):
|
def _header_table(self, html):
|
||||||
"""Return the entries for the header table"""
|
"""Return the entries for the header table"""
|
||||||
|
|
||||||
|
60
yli/shap.py
Normal file
60
yli/shap.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import pandas as pd
|
||||||
|
import patsy
|
||||||
|
|
||||||
|
from .utils import as_numeric, check_nan, cols_for_formula, convert_pandas_nullable
|
||||||
|
|
||||||
|
class ShapResult:
|
||||||
|
# TODO: Documentation
|
||||||
|
|
||||||
|
def __init__(self, model, shap_values, features):
|
||||||
|
self.model = model
|
||||||
|
self.shap_values = shap_values
|
||||||
|
self.features = features
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_xdata(model):
|
||||||
|
df = model.df()
|
||||||
|
if df is None:
|
||||||
|
raise Exception('Referenced DataFrame has been dropped')
|
||||||
|
dep = model.dep
|
||||||
|
|
||||||
|
# Check for/clean NaNs
|
||||||
|
# NaN warning/error will already have been handled in regress, so here we pass nan_policy='omit'
|
||||||
|
# Following this, we pass nan_policy='raise' to assert no NaNs remaining
|
||||||
|
df = df[[dep] + cols_for_formula(model.formula, df)]
|
||||||
|
df = check_nan(df, 'omit')
|
||||||
|
|
||||||
|
# Ensure numeric type for dependent variable
|
||||||
|
df[dep], dep_categories = as_numeric(df[dep])
|
||||||
|
|
||||||
|
# Convert pandas nullable types for independent variables as this breaks statsmodels
|
||||||
|
df = convert_pandas_nullable(df)
|
||||||
|
|
||||||
|
# Get xdata for SHAP
|
||||||
|
dmatrix = patsy.dmatrix(model.formula, df, return_type='dataframe')
|
||||||
|
xdata = dmatrix.iloc[:, 1:] # FIXME: Assumes zeroth term is intercept
|
||||||
|
|
||||||
|
return xdata
|
||||||
|
|
||||||
|
def mean(self):
|
||||||
|
return pd.Series(abs(self.shap_values).mean(axis=0), index=self.features)
|
||||||
|
|
||||||
|
def plot(self, **kwargs):
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import shap
|
||||||
|
|
||||||
|
model = self.model()
|
||||||
|
if model is None:
|
||||||
|
raise Exception('Referenced RegressionResult has been dropped')
|
||||||
|
|
||||||
|
xdata = self._get_xdata(model)
|
||||||
|
|
||||||
|
shap.summary_plot(self.shap_values, xdata, show=False, axis_color='black', **kwargs) # pass show=False to get gcf/gca
|
||||||
|
|
||||||
|
# Fix colour bar
|
||||||
|
# https://stackoverflow.com/questions/70461753/shap-the-color-bar-is-not-displayed-in-the-summary-plot
|
||||||
|
ax_colorbar = plt.gcf().axes[-1]
|
||||||
|
ax_colorbar.set_aspect('auto')
|
||||||
|
ax_colorbar.set_box_aspect(50)
|
||||||
|
|
||||||
|
return plt.gcf(), plt.gca()
|
Loading…
Reference in New Issue
Block a user