diff --git a/yli/regress.py b/yli/regress.py index 9290335..f0ae850 100644 --- a/yli/regress.py +++ b/yli/regress.py @@ -663,9 +663,9 @@ class RegressionModel: params.extend(s.beta.point for s in term.categories.values()) explainer = shap.LinearExplainer((np.array(params[1:]), params[0]), xdata, **kwargs) # FIXME: Assumes zeroth term is intercept - shap_values = explainer.shap_values(xdata).astype('float') + explanation = explainer(xdata) - return ShapResult(weakref.ref(self), shap_values, list(xdata.columns)) + return ShapResult(weakref.ref(self), explanation) class LikelihoodRatioTestResult(ChiSquaredResult): """ diff --git a/yli/shap.py b/yli/shap.py index 6119f46..41ef085 100644 --- a/yli/shap.py +++ b/yli/shap.py @@ -26,10 +26,9 @@ class ShapResult: See :meth:`yli.regress.RegressionModel.shap`. """ - def __init__(self, model, shap_values, features): + def __init__(self, model, explanation): self.model = model - self.shap_values = shap_values - self.features = features + self.explanation = explanation @staticmethod def _get_xdata(model): @@ -64,7 +63,8 @@ class ShapResult: :rtype: Series """ - return pd.Series(abs(self.shap_values).mean(axis=0), index=self.features) + raise NotImplementedError() + #return pd.Series(abs(self.shap_values).mean(axis=0), index=self.features) def plot(self, **kwargs): """ @@ -82,14 +82,6 @@ class ShapResult: if model is None: raise Exception('Referenced RegressionModel has been dropped') - xdata = self._get_xdata(model) - - shap.summary_plot(self.shap_values, xdata, show=False, axis_color='black', **kwargs) # pass show=False to get gcf/gca - - # Fix colour bar - # https://stackoverflow.com/questions/70461753/shap-the-color-bar-is-not-displayed-in-the-summary-plot - ax_colorbar = plt.gcf().axes[-1] - ax_colorbar.set_aspect('auto') - ax_colorbar.set_box_aspect(50) + shap.plots.beeswarm(self.explanation, show=False, axis_color='black', max_display=None, **kwargs) # pass show=False to get gcf/gca return plt.gcf(), plt.gca()