Skip to content

Commit

Permalink
changes for case weights and tidymodels/censored#163
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Apr 4, 2022
1 parent 0e9f4ba commit 152d138
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 15 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,4 @@ Config/rcmdcheck/ignore-inconsequential-notes: true
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.1.2
RoxygenNote: 7.1.2.9000
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# parsnip (development version)

* `xgb_train()` now allows for case weights via its new `weights` argument.

# parsnip 0.2.1

* Fixed a major bug in spark models induced in the previous version (#671).
Expand Down
48 changes: 34 additions & 14 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ check_args.boost_tree <- function(object) {
invisible(object)
}


# xgboost helpers --------------------------------------------------------------

#' Boosted trees via xgboost
Expand Down Expand Up @@ -256,11 +257,11 @@ check_args.boost_tree <- function(object) {
#' @keywords internal
#' @export
xgb_train <- function(
x, y,
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
event_level = c("first", "second"), ...) {
x, y,
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
event_level = c("first", "second"), weights = NULL, ...) {

event_level <- rlang::arg_match(event_level, c("first", "second"))
others <- list(...)
Expand Down Expand Up @@ -295,7 +296,11 @@ xgb_train <- function(
n <- nrow(x)
p <- ncol(x)

x <- as_xgb_data(x, y, validation, event_level)
x <-
as_xgb_data(x, y,
validation = validation,
event_level = event_level,
weights = weights)


if (!is.numeric(subsample) || subsample < 0 || subsample > 1) {
Expand Down Expand Up @@ -401,7 +406,7 @@ xgb_pred <- function(object, newdata, ...) {
}


as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "first", ...) {
lvls <- levels(y)
n <- nrow(x)

Expand All @@ -424,22 +429,36 @@ as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {

if (!inherits(x, "xgb.DMatrix")) {
if (validation > 0) {
# Split data
m <- floor(n * (1 - validation)) + 1
trn_index <- sample(1:n, size = max(m, 2))
wlist <-
list(validation = xgboost::xgb.DMatrix(x[-trn_index, ], label = y[-trn_index], missing = NA))
dat <- xgboost::xgb.DMatrix(x[trn_index, ], label = y[trn_index], missing = NA)
val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA)
watch_list <- list(validation = val_data)

info_list <- list(label = y[trn_index])
if (!is.null(weights)) {
info_list$weight <- weights[trn_index]
}
dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list)


} else {
dat <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
wlist <- list(training = dat)
info_list <- list(label = y)
if (!is.null(weights)) {
info_list$weight <- weights
}
dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
watch_list <- list(training = dat)
}
} else {
dat <- xgboost::setinfo(x, "label", y)
wlist <- list(training = dat)
if (!is.null(weights)) {
dat <- xgboost::setinfo(x, "weight", weights)
}
watch_list <- list(training = dat)
}

list(data = dat, watchlist = wlist)
list(data = dat, watchlist = watch_list)
}

get_event_level <- function(model_spec){
Expand All @@ -452,6 +471,7 @@ get_event_level <- function(model_spec){
event_level
}


#' @export
#' @rdname multi_predict
#' @param trees An integer vector for the number of trees in the ensemble.
Expand Down
1 change: 1 addition & 0 deletions man/xgb_train.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 152d138

Please sign in to comment.