Use shap "new" Explanation API
This commit is contained in:
parent
57e472ca09
commit
ed65dc3e8f
@ -663,9 +663,9 @@ class RegressionModel:
|
|||||||
params.extend(s.beta.point for s in term.categories.values())
|
params.extend(s.beta.point for s in term.categories.values())
|
||||||
|
|
||||||
explainer = shap.LinearExplainer((np.array(params[1:]), params[0]), xdata, **kwargs) # FIXME: Assumes zeroth term is intercept
|
explainer = shap.LinearExplainer((np.array(params[1:]), params[0]), xdata, **kwargs) # FIXME: Assumes zeroth term is intercept
|
||||||
shap_values = explainer.shap_values(xdata).astype('float')
|
explanation = explainer(xdata)
|
||||||
|
|
||||||
return ShapResult(weakref.ref(self), shap_values, list(xdata.columns))
|
return ShapResult(weakref.ref(self), explanation)
|
||||||
|
|
||||||
class LikelihoodRatioTestResult(ChiSquaredResult):
|
class LikelihoodRatioTestResult(ChiSquaredResult):
|
||||||
"""
|
"""
|
||||||
|
18
yli/shap.py
18
yli/shap.py
@ -26,10 +26,9 @@ class ShapResult:
|
|||||||
See :meth:`yli.regress.RegressionModel.shap`.
|
See :meth:`yli.regress.RegressionModel.shap`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, shap_values, features):
|
def __init__(self, model, explanation):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.shap_values = shap_values
|
self.explanation = explanation
|
||||||
self.features = features
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_xdata(model):
|
def _get_xdata(model):
|
||||||
@ -64,7 +63,8 @@ class ShapResult:
|
|||||||
:rtype: Series
|
:rtype: Series
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return pd.Series(abs(self.shap_values).mean(axis=0), index=self.features)
|
raise NotImplementedError()
|
||||||
|
#return pd.Series(abs(self.shap_values).mean(axis=0), index=self.features)
|
||||||
|
|
||||||
def plot(self, **kwargs):
|
def plot(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -82,14 +82,6 @@ class ShapResult:
|
|||||||
if model is None:
|
if model is None:
|
||||||
raise Exception('Referenced RegressionModel has been dropped')
|
raise Exception('Referenced RegressionModel has been dropped')
|
||||||
|
|
||||||
xdata = self._get_xdata(model)
|
shap.plots.beeswarm(self.explanation, show=False, axis_color='black', max_display=None, **kwargs) # pass show=False to get gcf/gca
|
||||||
|
|
||||||
shap.summary_plot(self.shap_values, xdata, show=False, axis_color='black', **kwargs) # pass show=False to get gcf/gca
|
|
||||||
|
|
||||||
# Fix colour bar
|
|
||||||
# https://stackoverflow.com/questions/70461753/shap-the-color-bar-is-not-displayed-in-the-summary-plot
|
|
||||||
ax_colorbar = plt.gcf().axes[-1]
|
|
||||||
ax_colorbar.set_aspect('auto')
|
|
||||||
ax_colorbar.set_box_aspect(50)
|
|
||||||
|
|
||||||
return plt.gcf(), plt.gca()
|
return plt.gcf(), plt.gca()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user