Use shap "new" Explanation API

This commit is contained in:
RunasSudo 2025-01-29 00:59:25 +11:00
parent 57e472ca09
commit ed65dc3e8f
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
2 changed files with 7 additions and 15 deletions

View File

@ -663,9 +663,9 @@ class RegressionModel:
params.extend(s.beta.point for s in term.categories.values()) params.extend(s.beta.point for s in term.categories.values())
explainer = shap.LinearExplainer((np.array(params[1:]), params[0]), xdata, **kwargs) # FIXME: Assumes zeroth term is intercept explainer = shap.LinearExplainer((np.array(params[1:]), params[0]), xdata, **kwargs) # FIXME: Assumes zeroth term is intercept
shap_values = explainer.shap_values(xdata).astype('float') explanation = explainer(xdata)
return ShapResult(weakref.ref(self), shap_values, list(xdata.columns)) return ShapResult(weakref.ref(self), explanation)
class LikelihoodRatioTestResult(ChiSquaredResult): class LikelihoodRatioTestResult(ChiSquaredResult):
""" """

View File

@ -26,10 +26,9 @@ class ShapResult:
See :meth:`yli.regress.RegressionModel.shap`. See :meth:`yli.regress.RegressionModel.shap`.
""" """
def __init__(self, model, shap_values, features): def __init__(self, model, explanation):
self.model = model self.model = model
self.shap_values = shap_values self.explanation = explanation
self.features = features
@staticmethod @staticmethod
def _get_xdata(model): def _get_xdata(model):
@ -64,7 +63,8 @@ class ShapResult:
:rtype: Series :rtype: Series
""" """
return pd.Series(abs(self.shap_values).mean(axis=0), index=self.features) raise NotImplementedError()
#return pd.Series(abs(self.shap_values).mean(axis=0), index=self.features)
def plot(self, **kwargs): def plot(self, **kwargs):
""" """
@ -82,14 +82,6 @@ class ShapResult:
if model is None: if model is None:
raise Exception('Referenced RegressionModel has been dropped') raise Exception('Referenced RegressionModel has been dropped')
xdata = self._get_xdata(model) shap.plots.beeswarm(self.explanation, show=False, axis_color='black', max_display=None, **kwargs) # pass show=False to get gcf/gca
shap.summary_plot(self.shap_values, xdata, show=False, axis_color='black', **kwargs) # pass show=False to get gcf/gca
# Fix colour bar
# https://stackoverflow.com/questions/70461753/shap-the-color-bar-is-not-displayed-in-the-summary-plot
ax_colorbar = plt.gcf().axes[-1]
ax_colorbar.set_aspect('auto')
ax_colorbar.set_box_aspect(50)
return plt.gcf(), plt.gca() return plt.gcf(), plt.gca()