
shap.prep.stack.data uses internal ID as grouping variable for hierarchical clustering in addition to shap values #31

Open
kransom14 opened this issue Jun 2, 2022 · 0 comments


To understand how shap.prep.stack.data() creates its groups, I tried to reproduce its grouping myself with stats::hclust() and cutree(). When given only the SHAP values from a model, hclust() and cutree() produce different group sizes than shap.prep.stack.data(). However, when I add a column with a sequential row ID and include it as a clustering variable, hclust() and cutree() give the same number of samples per group as shap.prep.stack.data(). This suggests that the function clusters on the row ID in addition to the SHAP values. A reproducible example follows.
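Matching group sizes, which is what the comparison in this issue checks, does not by itself prove that two clusterings are identical, because group labels are arbitrary and two different partitions can have the same sizes. A stricter check is whether two label vectors induce the same partition up to relabeling. The sketch below uses a hypothetical helper, `same_partition()`, which is not part of SHAPforxgboost or stats. It assumes both vectors refer to the same rows in the same order; I have not verified that `plot_data` from shap.prep.stack.data() keeps the original row order.

```r
# Hypothetical helper (not part of SHAPforxgboost or stats): TRUE when two
# label vectors induce the same partition of the observations, ignoring the
# arbitrary numbering of the groups.
same_partition <- function(a, b) {
  stopifnot(length(a) == length(b))
  tab <- table(a, b)
  # Same partition <=> every row and every column of the contingency
  # table has exactly one non-zero cell (a one-to-one label mapping).
  all(rowSums(tab > 0) == 1) && all(colSums(tab > 0) == 1)
}

# Same sizes and same partition, with labels 1 and 2 swapped
x <- c(1, 1, 2, 2, 3)
y <- c(2, 2, 1, 1, 3)
# Same sizes (2, 2, 1) but a different partition
z <- c(1, 2, 1, 2, 3)

same_partition(x, y)  # TRUE
same_partition(x, z)  # FALSE
```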

```r
library(SHAPforxgboost)
library(xgboost)
library(DALEX)
library(caret)

# use apartments data from DALEX
data("apartments")
head(apartments)
# one-hot encode the factor column (district) with caret
dummy <- dummyVars(" ~ .", data = apartments)
final_df <- data.frame(predict(dummy, newdata = apartments))
head(final_df)
# drop the first column (m2.price, the response) to build the feature matrix
X1 <- as.matrix(final_df[, -1])
mod1 <- xgboost::xgboost(
  data = X1, label = apartments$m2.price, gamma = 0, eta = 1,
  lambda = 0, nrounds = 1, verbose = FALSE)

shap_values <- shap.values(xgb_model = mod1, X_train = X1)
shap_values$mean_shap_score
shap_values_appts <- shap_values$shap_score

plot_data <- shap.prep.stack.data(shap_contrib = shap_values_appts,
                                  n_groups = 4)
summary(as.factor(plot_data$group))
#   1   2   3   4
# 606  92 215  87

# calculate clusters with hclust(), as is done internally in shap.prep.stack.data(),
# including the scaling that shap.prep.stack.data() performs
h <- hclust(dist(scale(shap_values_appts)), method = "ward.D")
groups <- cutree(h, 4)
summary(as.factor(groups))
#   1   2   3   4
# 307 336 270  87

# add a sequential row ID column to the SHAP values and recluster:
# the group sizes now match shap.prep.stack.data() (only the group labels are permuted)
shap_values_appts_id <- shap_values_appts
shap_values_appts_id$ID <- seq_len(nrow(shap_values_appts_id))

h2 <- hclust(dist(scale(shap_values_appts_id)), method = "ward.D")
groups2 <- cutree(h2, 4)
summary(as.factor(groups2))
#   1   2   3   4
# 215 606  87  92
```
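Why an ID column matters: after `scale()`, a sequential row ID has the same unit variance as each SHAP column, so in the Euclidean distance it can weigh as much as any single feature, and row order (an arbitrary property of the data) can then influence which cluster a sample lands in. The sketch below illustrates this on synthetic data, not on the apartments SHAP values; the data, `n`, and the helper `cluster_groups()` are hypothetical and only mirror the `hclust(dist(scale(...)), method = "ward.D")` + `cutree()` call used above.

```r
# Minimal sketch on synthetic data (not the apartments SHAP values):
# cluster the scaled features with and without an appended sequential row ID.
set.seed(1)
n <- 200
# Two well-separated groups in f1, shuffled so that group membership is
# unrelated to row order
f1 <- sample(c(rnorm(n / 2, mean = -3), rnorm(n / 2, mean = 3)))
shap_like <- data.frame(f1 = f1, f2 = rnorm(n))

# Hypothetical helper mirroring the calls used in the issue
cluster_groups <- function(df, k = 4, method = "ward.D") {
  h <- stats::hclust(stats::dist(scale(df)), method = method)
  stats::cutree(h, k)
}

groups_no_id <- cluster_groups(shap_like)

shap_like_id <- shap_like
shap_like_id$ID <- seq_len(n)  # the row index becomes an extra feature
groups_with_id <- cluster_groups(shap_like_id)

# Rows: clusters without ID; columns: clusters with ID.
# Rows or columns with several non-zero cells mean the partitions differ.
print(table(no_id = groups_no_id, with_id = groups_with_id))
```

Inspecting the cross-table, rather than only the marginal group sizes, shows whether the row ID actually moved samples between clusters.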
