
Xgboost implement Aft and Cox with weights, untested need support for testing #163

Open · wants to merge 10 commits into base: main
Conversation

@brunocarlin commented Mar 16, 2022

A very messy start to the functionality. I couldn't remove this .gitignore entry, sorry. We will eventually have to implement survival:cox as well, but that returns hazard ratios instead of time, so I don't know how to separate the two. This objective was surprisingly easy to add; I just need to decide how to stop the user from requesting the wrong prediction type for the chosen objective. Here is a reprex of it working:

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(censored)
#> Loading required package: survival
library(survival)
tidymodels_prefer()

data(cancer)

lung <- lung %>% drop_na()
lung_train <- lung[-c(1:5), ]
lung_test <- lung[1:5, ]


test <-
  boost_tree()|> set_engine('xgboost') |> set_mode('censored regression')

test |> 
  translate()
#> Boosted Tree Model Specification (censored regression)
#> 
#> Computational engine: xgboost 
#> 
#> Model fit template:
#> censored::xgb_train_censored(x = missing_arg(), y = missing_arg(), 
#>     nthread = 1, verbose = 0)

set.seed(1)
bt_fit <- test %>% fit(Surv(time, status) ~ ., data = lung_train)
bt_fit
#> parsnip model object
#> 
#> ##### xgb.Booster
#> raw: 45.4 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "survival:aft"), data = x$data, 
#>     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "survival:aft", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>     iter training_aft_nloglik
#>        1            14.698676
#>        2             9.883017
#> ---                          
#>       14             4.564666
#>       15             4.553461

predict(
  bt_fit, 
  lung_train,
  type = "time",
)
#> # A tibble: 162 x 1
#>    .pred_time
#>         <dbl>
#>  1      231. 
#>  2      182. 
#>  3      283. 
#>  4      504. 
#>  5      423. 
#>  6      423. 
#>  7       57.0
#>  8      327. 
#>  9       94.3
#> 10      369. 
#> # ... with 152 more rows



dtrain <- lung_train |> select(-time,-status)

xgb_train <- xgboost::xgb.DMatrix(dtrain |> as.matrix())

xgboost::setinfo(xgb_train,'label_lower_bound',lung_train$time)
#> [1] TRUE

upper <- lung_train |> transmute(if_else(status == 2,time,+Inf)) |> pull()

xgboost::setinfo(xgb_train,'label_upper_bound',upper)
#> [1] TRUE

params <- list(objective='survival:aft',
               eval_metric='aft-nloglik')
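As an aside, the survival:aft objective also accepts distribution parameters; the sketch below spells out xgboost's documented defaults for them:

```r
# Sketch: survival:aft also takes distribution parameters; these are
# xgboost's documented defaults, written out for clarity.
params_aft <- list(
  objective                   = 'survival:aft',
  eval_metric                 = 'aft-nloglik',
  aft_loss_distribution       = 'normal',  # alternatives: 'logistic', 'extreme'
  aft_loss_distribution_scale = 1.0
)
```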


bst = xgboost::xgb.train(params,xgb_train,nrounds = 15)

predict(bst,xgb_train)
#>   [1]  231.001297  181.531784  282.884491  503.601868  422.816559  422.607727
#>   [7]   56.973206  326.810364   94.282822  368.609589  460.003082  576.603760
#>  [13]  147.651245  316.400269   12.624409  563.194641   26.933863  110.736954
#>  [19]   62.588371  673.044067 1159.743286  100.161781  651.765503  273.160431
#>  [25]  174.070190  448.462006  303.948975   92.200394  348.106537  346.903015
#>  [31]  564.314819  692.380981  109.115311  412.932831  400.043335    7.912508
#>  [37]  471.043793  292.577637  514.660706  228.112335   72.993294  187.698166
#>  [43]   94.267357 1008.859192  342.155762  223.576965  619.221436  288.998474
#>  [49]   14.276150  326.450897  409.521088  642.956665  374.180908  181.168594
#>  [55]  430.338196  124.323845  744.351624  173.072540 1365.703857  285.477417
#>  [61]  616.979126  228.811066 1280.812744  138.947388  521.381287  192.222610
#>  [67]  218.413223   31.944872  451.967926 1707.287720  454.458130  175.718262
#>  [73] 1395.177002  335.843140  332.520386   48.389252  210.697769  272.518372
#>  [79]   16.216888  196.842133  527.387573  279.509064  291.177765  220.606888
#>  [85]  534.116211   58.688663  482.041168  207.826431  134.650085  147.338837
#>  [91] 1151.077637  305.285736  349.505981  208.674271  922.822937  416.841003
#>  [97]  341.617645  187.758224  779.465515  242.261658   83.590912  868.222473
#> [103]  262.357697  233.360672   32.361862  215.750290  150.604263  278.074127
#> [109]  174.933243  768.926025  933.004395  254.248947  708.768555  293.480804
#> [115]  617.796387  898.248718  347.187836  217.031403  371.995361  583.900146
#> [121]  466.231567  206.762955 1251.060913  416.042511 1289.109497  241.221222
#> [127]  205.523346  883.585449  566.430298  900.743713  936.729309 1469.551025
#> [133]  265.212311  464.018982  163.037277  321.616638  257.634216  589.192383
#> [139]  313.883698  441.816956  918.941223  911.709961  171.871414  128.250366
#> [145]   72.955261 1630.290405 1294.991943  845.519836  221.963730  733.295654
#> [151]  521.173767  461.724884  446.850403  162.951401  449.657074  666.839661
#> [157]  366.623871  586.267395 1189.139893  663.624451  508.253052  720.888794

Created on 2022-03-15 by the reprex package (v2.0.1)

@brunocarlin (Author)

I will try to implement #161 on this branch

@brunocarlin (Author)

We may also need to create a dummy metric for censored regression models in order to leverage the stacks framework, or just reimplement it in censored. The individual components wouldn't be tuned on their own, but the stacked predictor could be tuned, because it can compute metrics like the C-index and the Brier score, which the base models can't.

@brunocarlin (Author)

I think I got close to what I wanted to achieve. Maybe we can create a simpler blend_predictions for survival models: as the documentation clearly shows, not all models will be able to produce all prediction types, but we can stack them as long as there are non-time-dependent predictions.

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(censored)
#> Loading required package: survival
library(survival)
tidymodels_prefer()

data(cancer)

lung <- lung %>% drop_na()
lung_train <- lung[-c(1:5), ]
lung_test <- lung[1:5, ]


mod_spec1 <- 
  decision_tree() %>%
  set_mode("censored regression")

mod_spec2 <- 
  boost_tree(trees = 100) %>%
  set_mode("censored regression") |> 
  set_engine('xgboost',aft_loss_distribution_scale = 1)

models <- list(mod_spec1,mod_spec2)


fits <- models |> map(fit,data = lung_train,formula = Surv(time, status) ~ .)

table_specs <- tibble(fits)

table_predicts <- table_specs |> 
  rowwise() |> 
  mutate(predicts = list(predict(fits, lung_train)))

train_predictions_models <- table_predicts |>
  ungroup() |>
  select(predicts) |>
  mutate(model_number = stringr::str_c('model_', row_number())) |>
  unnest(cols = predicts) |>
  pivot_wider(
    names_from = model_number,
    values_from = -model_number,
    values_fn = list
  ) |> unnest(everything())

# create the train stacked model train

stacked_data <- lung_train |> 
  select(time,status) |>
  bind_cols(train_predictions_models)
  
  

spec_stack <-
  proportional_hazards(penalty = .1) |>
  set_engine("glmnet",
             lower.limits = 0, lambda.min.ratio = 0)



fit_stacked <- spec_stack |> fit(Surv(time, status) ~ .,data = stacked_data)

fit_stacked |> predict(stacked_data,type = 'survival',time = c(1,10,100,200)) |> 
  slice(119) |> 
  tidyr::unnest(col = .pred)
#> # A tibble: 4 x 2
#>   .time .pred_survival
#>   <dbl>          <dbl>
#> 1     1          1    
#> 2    10          0.995
#> 3   100          0.858
#> 4   200          0.689

table_predicts_testing <- table_predicts |> 
  mutate(predict_testing = list(predict(fits, lung_test)))


test_predictions_models <- table_predicts_testing |>
  ungroup() |>
  select(predict_testing) |>
  mutate(model_number = stringr::str_c('model_', row_number())) |>
  unnest(cols = predict_testing ) |>
  pivot_wider(
    names_from = model_number,
    values_from = -model_number,
    values_fn = list
  ) |> unnest(everything())

stacked_data_testing <- lung_test |> 
  select(time,status) |>
  bind_cols(test_predictions_models)

fit_stacked |> predict(stacked_data_testing,type = 'survival',time = c(1,10,100,200)) |> 
  slice(1:5) |> 
  tidyr::unnest(col = .pred)
#> # A tibble: 20 x 2
#>    .time .pred_survival
#>    <dbl>          <dbl>
#>  1     1          1    
#>  2    10          0.995
#>  3   100          0.858
#>  4   200          0.689
#>  5     1          1    
#>  6    10          0.993
#>  7   100          0.803
#>  8   200          0.587
#>  9     1          1    
#> 10    10          0.997
#> 11   100          0.910
#> 12   200          0.795
#> 13     1          1    
#> 14    10          0.994
#> 15   100          0.835
#> 16   200          0.646
#> 17     1          1    
#> 18    10          0.994
#> 19   100          0.835
#> 20   200          0.646

Created on 2022-03-17 by the reprex package (v2.0.1)

@hfrick (Member) left a comment

Thanks for this! I left some comments but generally I think we should export/use some stuff from parsnip so that we don't duplicate efforts.

Can you elaborate a bit what you mean with your comment about the hazard? And can you add tests for the new functionality as well, please?

eng = "xgboost",
mode = "censored regression",
options = list(
predictor_indicators = "none",
Member:

I think we should leave this as "one_hot" like in parsnip

Author:

The cox objective from xgboost says:

survival:cox: Cox regression for right censored survival time data (negative values are considered right censored). Note that predictions are returned on the hazard ratio scale (i.e., as HR = exp(marginal_prediction) in the proportional hazard function h(t) = h0(t) * HR).

But the aft objective says

survival:aft: Accelerated failure time model for censored survival time data. See Survival Analysis with Accelerated Failure Time for details. This objective returns predictions on the time scale.
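To make the scale difference concrete, here is a tiny base-R sketch (the margins are made-up numbers, not model output):

```r
# survival:cox predictions come back as HR = exp(marginal_prediction),
# while survival:aft predictions are already on the time scale.
marginal_prediction <- c(-0.5, 0, 1.2)   # hypothetical booster margins
hazard_ratio <- exp(marginal_prediction)
round(hazard_ratio, 3)
#> [1] 0.607 1.000 3.320
```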




xgb_pred <- function(object, newdata, ...) {
Member:

I'll make a PR to parsnip to export xgb_pred() from there so we can use it here and don't need a copy

Member:

this is now exported in the dev version of parsnip as xgb_predict()




as_xgb_data <- function(x, y, validation = 0, ...) {
Member:

It would be helpful to develop a version that works for all modes here, and when it's ready, open a PR on parsnip. Could you extend the parsnip version?

Author:

Maybe... we would need to check the current objective. That was my initial idea for implementing AFT in parsnip directly, but if we adapt just as_xgb_data(), the rest should work rather well.

Member:

ok, great! let's keep the development of changes to as_xgb_data() here for now though so we don't have to coordinate across multiple PRs.

value = list(
interface = "matrix",
protect = c("x", "y"),
func = c(pkg = "censored", fun = "xgb_train_censored"),
Member:

xgb_train() is exported from parsnip, can you use that with objective = "survival:aft"? you should be able to set that in the defaults in the line below and also protect it in the line above.

Author:

Oh, that is genius, now I understand what protect does! OK, that makes a lot of sense. I just need to get as_xgb_data() to work with Cox and AFT; I can PR this within the week, no problem. It shouldn't break parsnip, since I will branch on the objective, and parsnip already has defaults for objectives.

Member:

What kind of labels do you need to set for survival:cox? Different than for survival:aft?

@brunocarlin (Author) Mar 22, 2022:

Yes, xgboost decided that for Cox you only need 'label': censored times are negated and uncensored times stay positive, so 50+ becomes -50 and an uncensored 100 stays 100.
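In base R that encoding is one line; cox_label here is a hypothetical helper (using a logical event indicator rather than lung's 1/2 status coding):

```r
# Encode right-censored data for xgboost's survival:cox: per the xgboost docs,
# negative label values mark right-censored observations.
cox_label <- function(time, event) {
  # event: TRUE if the event was observed, FALSE if right censored
  ifelse(event, time, -time)
}
cox_label(time = c(100, 50), event = c(TRUE, FALSE))
#> [1] 100 -50
```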

@hfrick (Member)

hfrick commented Mar 22, 2022

Ah, I think I understand better now. The resulting hazard predictions for the Cox model would be a different prediction type in censored (type = "hazard"). So that's not a problem for censored, but I'm not sure that stacking different prediction types is a good idea.

Using the survival:cox objective for xgboost could mean an engine argument, but it could also mean, e.g., a different engine. Let's focus on survival:aft first.

Also, just for clarification: censored only contains code for model engines (for fitting and predicting). Metrics, stacking, and so on go into different places. We have plans to extend support for survival models across all necessary parts of tidymodels but are currently focused on the model definition/fitting/predicting part.

@brunocarlin (Author)

I like the idea of ensembling terrible base models into half-decent ones with a "stable" model at the end, but I can see that would be harder to implement. I agree with you: we can focus on implementing AFT. I just don't want to restrict it too much, and we should think about Cox a little so we don't end up creating two functions. I also think we have the opportunity to support interval-censored and left-censored data with AFT, though that would not work for Cox.

I couldn't get protect to work at fit time because I won't know which prediction type someone wants until they call predict(). Also, I called the Cox predictions linear_preds, but I am pretty sure that is not right.
@brunocarlin (Author)

I have updated the function to handle both AFT and Cox, and I also check that the objective matches the requested prediction type.
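The guard is conceptually a one-liner; check_objective below is an illustrative sketch, not the actual implementation:

```r
# Hypothetical sketch of the objective / prediction-type guard.
check_objective <- function(fitted_objective, required_objective) {
  if (!identical(fitted_objective, required_objective)) {
    stop(sprintf("The objective should be %s not %s",
                 required_objective, fitted_objective))
  }
  invisible(TRUE)
}
check_objective("survival:aft", "survival:aft")  # passes silently
```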

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(censored)
#> Loading required package: survival
library(survival)
tidymodels_prefer()

data(cancer)

lung <- lung %>% drop_na()
lung_train <- lung[-c(1:5), ]
lung_test <- lung[1:5, ]


test <-
  boost_tree()|> set_engine('xgboost',objective  = 'survival:cox') |> set_mode('censored regression')

test |> 
  translate()
#> Boosted Tree Model Specification (censored regression)
#> 
#> Engine-Specific Arguments:
#>   objective = survival:cox
#> 
#> Computational engine: xgboost 
#> 
#> Model fit template:
#> censored::xgb_train_censored(x = missing_arg(), y = missing_arg(), 
#>     objective = "survival:cox", nthread = 1, verbose = 0)

set.seed(1)
bt_fit <- test %>% fit(Surv(time, status) ~ ., data = lung_train)
bt_fit
#> parsnip model object
#> 
#> ##### xgb.Booster
#> raw: 55.4 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "survival:cox"), data = x$data, 
#>     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "survival:cox", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>     iter training_cox_nloglik
#>        1             3.967138
#>        2             3.835229
#> ---                          
#>       14             3.163877
#>       15             3.109603

predict(
  bt_fit, 
  lung_train,
  type = "linear_pred",
)
#> # A tibble: 162 x 1
#>    .pred_linear_pred
#>                <dbl>
#>  1             1.57 
#>  2             3.68 
#>  3             0.552
#>  4             0.304
#>  5             0.243
#>  6             0.159
#>  7            15.7  
#>  8             0.631
#>  9             2.89 
#> 10             0.517
#> # ... with 152 more rows

predict(
  bt_fit, 
  lung_train,
  type = "time",
)
#> Error:
#> ! The objective should be survival:aft not survival:cox

Created on 2022-03-24 by the reprex package (v2.0.1)

@topepo (Member)

topepo commented Mar 30, 2022

@brunocarlin You might want to look at the changes in the xgb functions in parsnip that are being updated for case weights: tidymodels/parsnip@2100e8f

It will be helpful when we merge this if you use some of that restructuring.

@brunocarlin (Author)

@topepo Got it. Should I implement case weights for survival cases as well? Or just use the new naming and throw an error if the user passes weights?
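If weights do go in, they would presumably attach to the xgb.DMatrix the same way the AFT label bounds do; a sketch, assuming the xgboost package is installed and using mtcars purely as stand-in data:

```r
library(xgboost)

x <- as.matrix(mtcars[, -1])
dtrain <- xgboost::xgb.DMatrix(x)
# Case weights ride along on the DMatrix, like the AFT bounds set earlier:
xgboost::setinfo(dtrain, "weight", rep(1, nrow(x)))
length(xgboost::getinfo(dtrain, "weight"))
#> [1] 32
```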

topepo added a commit to tidymodels/parsnip that referenced this pull request Apr 4, 2022
@topepo (Member)

topepo commented Apr 4, 2022

@brunocarlin I added a PR with the case weight changes to xgb_train(). This will be better to work with than the previous version.

Let me know if you want a hand with this part; unfortunately Surv objects are pretty weird

library(survival)

cens_data <- with(aml[1:5,], Surv(time, status))

# Looks like a vector:
cens_data
#> [1]  9  13  13+ 18  23

# but is a hidden matrix: 
length(cens_data)    # <- as if it is a vector
#> [1] 5
dim(cens_data)       # <- this object is sus
#> [1] 5 2

# Be careful about subsetting
cens_data[1:3]
#> [1]  9  13  13+
cens_data[1:3, ]
#> [1]  9  13  13+
cens_data[1:3, 1]
#> [1]  9 13 13
cens_data[1:3, 2]   # <- these are different when you have interval censoring
#> [1] 1 1 0


# is it a surv object? 
inherits(cens_data, "Surv")
#> [1] TRUE

# is it numeric? 
inherits(cens_data, "numeric")
#> [1] FALSE
# but...
is.numeric(cens_data)
#> [1] TRUE
# ¯\_(ツ)_/¯

Created on 2022-04-04 by the reprex package (v2.0.1)
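Given those quirks, one reasonably safe way to pull the pieces out is to drop the class and index the underlying matrix by column name (survival ships with R, so this sketch should run as-is):

```r
library(survival)

cens_data <- with(aml[1:5, ], Surv(time, status))
# unclass() exposes the two-column matrix behind a right-censored Surv
# object, sidestepping its surprising `[` method shown above.
m <- unclass(cens_data)
m[, "time"]
m[, "status"]
```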

@topepo (Member)

topepo commented Apr 4, 2022

@brunocarlin In regard to case weights, the main changes are in tidymodels/parsnip#696 so no need to do anything else for that.

That won't work until tidymodels/parsnip#692 is merged (which will be a while).

made some changes in order to look more like the proposed new naming scheme
fixed some bugs on the first update
@brunocarlin (Author)

@topepo I think I got the gist of the naming scheme. I have opted to use train instead of training for some variables; is that OK?

topepo added a commit to tidymodels/parsnip that referenced this pull request Apr 7, 2022
@hfrick (Member) left a comment

@brunocarlin If you want to do both model types then let's make them two separate engines: xgboost_aft and xgboost_cox. That fits much better with the design of parsnip where a combination of model spec (here: boost_tree()) and engine should only fit one model type. So let's protect the objective arg and remove the pre hooks in the prediction modules. Since each model (AFT and Cox) can only provide one type of prediction, they should both only have one prediction module each (the thing you set with set_pred()). Could you also formalize your testing in unit tests, please?
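For reference, the two-engine split would look roughly like this parsnip registration sketch; the engine name, protect set, and defaults are assumptions for illustration, and the xgboost_cox engine would mirror it with objective = "survival:cox" plus a linear_pred prediction module:

```r
library(parsnip)

# Hypothetical registration for a dedicated AFT engine (sketch only).
set_model_engine("boost_tree", mode = "censored regression", eng = "xgboost_aft")
set_dependency("boost_tree", eng = "xgboost_aft", pkg = "xgboost")

set_fit(
  model = "boost_tree",
  eng = "xgboost_aft",
  mode = "censored regression",
  value = list(
    interface = "matrix",
    protect = c("x", "y", "objective"),  # fix the objective per engine
    func = c(pkg = "censored", fun = "xgb_train_censored"),
    defaults = list(objective = "survival:aft")
  )
)
```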

@brunocarlin (Author)

@hfrick OK, I will try to work on it this week.

I keep getting:

Error in inherits(new_data, "xgb.DMatrix") :
  argument "new_data" is missing, with no default
@brunocarlin (Author)

@topepo and @hfrick Hey guys, I am back, sorry for the long delay

I have tried splitting the engines, but when importing the new xgb_predict() I get the error below:

library(tidymodels)
library(censored)
#> Loading required package: survival
library(survival)
tidymodels_prefer()

data(cancer)

lung <- lung %>% drop_na()
lung_train <- lung[-c(1:5), ]
lung_test <- lung[1:5, ]


test <-
  boost_tree()|> set_engine('xgboost_aft') |> set_mode('censored regression')

test |>
  translate()
#> Boosted Tree Model Specification (censored regression)
#> 
#> Computational engine: xgboost_aft 
#> 
#> Model fit template:
#> censored::xgb_train_censored(x = missing_arg(), y = missing_arg(), 
#>     objective = missing_arg(), nthread = 1, verbose = 0, objective = "survival:aft")

set.seed(1)
bt_fit <- test %>% fit(Surv(time, status) ~ ., data = lung_train)
bt_fit
#> parsnip model object
#> 
#> ##### xgb.Booster
#> raw: 36.3 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "survival:aft"), data = x$data, 
#>     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "survival:aft", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>     iter training_aft_nloglik
#>        1            14.698676
#>        2             9.883017
#> ---                          
#>       14             4.564666
#>       15             4.553461

predict(
  bt_fit,
  lung_test,
  type = "linear_pred",
)
#> Error in `check_spec_pred_type()`:
#> ! No linear_pred prediction method available for this model.
#> • Value for `type` should be one of: 'time'

#> Backtrace:
#>     ▆
#>  1. ├─stats::predict(bt_fit, lung_test, type = "linear_pred", )
#>  2. └─parsnip::predict.model_fit(...)
#>  3.   ├─parsnip::predict_linear_pred(...)
#>  4.   └─parsnip::predict_linear_pred.model_fit(...)
#>  5.     └─parsnip:::check_spec_pred_type(object, "linear_pred")
#>  6.       └─rlang::abort(...)

predict(
  bt_fit,
  lung_test,
  type = 'time'
)
#> Error in inherits(new_data, "xgb.DMatrix"): argument "new_data" is missing, with no default

Created on 2022-08-27 with reprex v2.0.2

@hfrick (Member)

hfrick commented Oct 5, 2022

Hey @brunocarlin! Apologies that it has taken a while to get back to you! I've been meaning to return to this PR but haven't found the bandwidth for it. We have now decided that we'll focus on advancing support for censored regression in other places of tidymodels. This means that I'll take a deeper look into debugging the error in this branch later, most likely when I return to adding additional engines to censored. I appreciate you getting this started and the various pointers to the relevant docs of xgboost!

@hfrick (Member)

hfrick commented Oct 5, 2022

note to self: add any additional learnings from this to the checklist for adding an engine to what is currently here: tidymodels/tidymodels#97

@brunocarlin (Author)

brunocarlin commented Oct 5, 2022

@hfrick No problem, happy to help on this; I also learned a lot.

@brunocarlin changed the title from "xgboost:survival:aft" to "Xgboost implement Aft and Cox with weights, untested need support for testing" on Jul 4, 2024
@brunocarlin (Author)

After all these years, lol. Anyway, @hfrick, this should cover the basic functionality for the AFT and Cox models that the xgboost package provides out of the box. I have basically just wrapped the data back and forth in the right format for the dreaded xgb.DMatrix. It should work out of the box according to my own testing, but I haven't stress-tested arguments, although I don't see why it shouldn't work with weights and other arguments. Also, I used the version of censored I had on hand; maybe something else changed for 0.3.

library(tidymodels)
library(censored)
#> Loading required package: survival
library(survival)
tidymodels_prefer()

data(cancer)

lung <- lung %>% drop_na()
lung_train <- lung[-c(1:5), ]
lung_test <- lung[1:5, ]


# Cox

test <-
  boost_tree()|> set_engine('xgboost_cox') |> set_mode('censored regression')

test |>
  translate()
#> Boosted Tree Model Specification (censored regression)
#> 
#> Computational engine: xgboost_cox 
#> 
#> Model fit template:
#> censored::xgb_train_censored(x = missing_arg(), y = missing_arg(), 
#>     objective = missing_arg(), nthread = 1, verbose = 0, objective = "survival:cox")
set.seed(1)
bt_fit <- test %>% fit(Surv(time, status) ~ ., data = lung_train)
bt_fit
#> parsnip model object
#> 
#> ##### xgb.Booster
#> raw: 40.5 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "survival:cox"), data = x$data, 
#>     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "survival:cox", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>      iter training_cox_nloglik
#>     <num>                <num>
#>         1             3.967019
#>         2             3.840237
#> ---                           
#>        14             3.095573
#>        15             3.054746
predict(
  bt_fit,
  lung_test,
  type = "linear_pred",
)
#> # A tibble: 5 × 1
#>   .pred_linear_pred
#>               <dbl>
#> 1             0.351
#> 2             4.41 
#> 3             2.23 
#> 4             4.50 
#> 5             2.36
predict(
  bt_fit,
  lung_test,
  type = 'time'
)
#> Error in `check_spec_pred_type()`:
#> ! No time prediction method available for this model.
#> • Value for `type` should be one of: 'linear_pred'
# Aft

test <-
  boost_tree()|> set_engine('xgboost_aft') |> set_mode('censored regression')

test |>
  translate()
#> Boosted Tree Model Specification (censored regression)
#> 
#> Computational engine: xgboost_aft 
#> 
#> Model fit template:
#> censored::xgb_train_censored(x = missing_arg(), y = missing_arg(), 
#>     objective = missing_arg(), nthread = 1, verbose = 0, objective = "survival:aft")
set.seed(1)
bt_fit <- test %>% fit(Surv(time, status) ~ ., data = lung_train)
bt_fit
#> parsnip model object
#> 
#> ##### xgb.Booster
#> raw: 36.3 Kb 
#> call:
#>   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
#>     colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
#>     subsample = 1, objective = "survival:aft"), data = x$data, 
#>     nrounds = 15, watchlist = x$watchlist, verbose = 0, nthread = 1)
#> params (as set within xgb.train):
#>   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "1", objective = "survival:aft", nthread = "1", validate_parameters = "TRUE"
#> xgb.attributes:
#>   niter
#> callbacks:
#>   cb.evaluation.log()
#> # of features: 8 
#> niter: 15
#> nfeatures : 8 
#> evaluation_log:
#>      iter training_aft_nloglik
#>     <num>                <num>
#>         1            14.698676
#>         2             9.883017
#> ---                           
#>        14             4.564666
#>        15             4.553461
predict(
  bt_fit,
  lung_test,
  type = "linear_pred",
)
#> Error in `check_spec_pred_type()`:
#> ! No linear_pred prediction method available for this model.
#> • Value for `type` should be one of: 'time'
predict(
  bt_fit,
  lung_test,
  type = 'time'
)
#> # A tibble: 5 × 1
#>   .pred_time
#>        <dbl>
#> 1      420. 
#> 2      239. 
#> 3      120. 
#> 4       78.7
#> 5      350.

Created on 2024-07-04 with reprex v2.1.0

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.1 (2024-06-14 ucrt)
#>  os       Windows 11 x64 (build 22631)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  Portuguese_Brazil.utf8
#>  ctype    Portuguese_Brazil.utf8
#>  tz       America/Sao_Paulo
#>  date     2024-07-04
#>  pandoc   3.1.11 @ C:/Program Files/RStudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
#>  broom        * 1.0.6      2024-05-17 [1] CRAN (R 4.4.1)
#>  cachem         1.1.0      2024-05-16 [1] CRAN (R 4.4.1)
#>  censored     * 0.0.0.9000 2024-07-04 [1] local
#>  class          7.3-22     2023-05-03 [2] CRAN (R 4.4.1)
#>  cli            3.6.3      2024-06-21 [1] CRAN (R 4.4.1)
#>  codetools      0.2-20     2024-03-31 [2] CRAN (R 4.4.1)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.4.1)
#>  conflicted     1.2.0      2023-02-01 [1] CRAN (R 4.4.1)
#>  data.table     1.15.4     2024-03-30 [1] CRAN (R 4.4.1)
#>  dials        * 1.2.1      2024-02-22 [1] CRAN (R 4.4.1)
#>  DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.4.1)
#>  digest         0.6.36     2024-06-23 [1] CRAN (R 4.4.1)
#>  dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.4.1)
#>  evaluate       0.24.0     2024-06-10 [1] CRAN (R 4.4.1)
#>  fansi          1.0.6      2023-12-08 [1] CRAN (R 4.4.1)
#>  fastmap        1.2.0      2024-05-15 [1] CRAN (R 4.4.1)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.4.1)
#>  fs             1.6.4      2024-04-25 [1] CRAN (R 4.4.1)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.4.1)
#>  future         1.33.2     2024-03-26 [1] CRAN (R 4.4.1)
#>  future.apply   1.11.2     2024-03-28 [1] CRAN (R 4.4.1)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.4.1)
#>  ggplot2      * 3.5.1      2024-04-23 [1] CRAN (R 4.4.1)
#>  globals        0.16.3     2024-03-08 [1] CRAN (R 4.4.0)
#>  glue           1.7.0      2024-01-09 [1] CRAN (R 4.4.1)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.4.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.4.1)
#>  gtable         0.3.5      2024-04-22 [1] CRAN (R 4.4.1)
#>  hardhat        1.4.0      2024-06-02 [1] CRAN (R 4.4.1)
#>  htmltools      0.5.8.1    2024-04-04 [1] CRAN (R 4.4.1)
#>  infer        * 1.0.7      2024-03-25 [1] CRAN (R 4.4.1)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.4.1)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.4.1)
#>  jsonlite       1.8.8      2023-12-04 [1] CRAN (R 4.4.1)
#>  knitr          1.47       2024-05-29 [1] CRAN (R 4.4.1)
#>  lattice        0.22-6     2024-03-20 [2] CRAN (R 4.4.1)
#>  lava           1.8.0      2024-03-05 [1] CRAN (R 4.4.1)
#>  lhs            1.2.0      2024-06-30 [1] CRAN (R 4.4.1)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.4.1)
#>  listenv        0.9.1      2024-01-29 [1] CRAN (R 4.4.1)
#>  lubridate      1.9.3      2023-09-27 [1] CRAN (R 4.4.1)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.4.1)
#>  MASS           7.3-60.2   2024-04-26 [2] CRAN (R 4.4.1)
#>  Matrix         1.7-0      2024-04-26 [2] CRAN (R 4.4.1)
#>  memoise        2.0.1      2021-11-26 [1] CRAN (R 4.4.1)
#>  modeldata    * 1.4.0      2024-06-19 [1] CRAN (R 4.4.1)
#>  munsell        0.5.1      2024-04-01 [1] CRAN (R 4.4.1)
#>  nnet           7.3-19     2023-05-03 [2] CRAN (R 4.4.1)
#>  parallelly     1.37.1     2024-02-29 [1] CRAN (R 4.4.0)
#>  parsnip      * 1.2.1      2024-03-22 [1] CRAN (R 4.4.1)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.4.1)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.4.1)
#>  prodlim        2024.06.25 2024-06-24 [1] CRAN (R 4.4.1)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.4.1)
#>  R.cache        0.16.0     2022-07-21 [1] CRAN (R 4.4.1)
#>  R.methodsS3    1.8.2      2022-06-13 [1] CRAN (R 4.4.0)
#>  R.oo           1.26.0     2024-01-24 [1] CRAN (R 4.4.0)
#>  R.utils        2.12.3     2023-11-18 [1] CRAN (R 4.4.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.4.1)
#>  Rcpp           1.0.12     2024-01-09 [1] CRAN (R 4.4.1)
#>  recipes      * 1.0.10     2024-02-18 [1] CRAN (R 4.4.1)
#>  reprex         2.1.0      2024-01-11 [1] CRAN (R 4.4.1)
#>  rlang          1.1.4      2024-06-04 [1] CRAN (R 4.4.1)
#>  rmarkdown      2.27       2024-05-17 [1] CRAN (R 4.4.1)
#>  rpart          4.1.23     2023-12-05 [2] CRAN (R 4.4.1)
#>  rsample      * 1.2.1      2024-03-25 [1] CRAN (R 4.4.1)
#>  rstudioapi     0.16.0     2024-03-24 [1] CRAN (R 4.4.1)
#>  scales       * 1.3.0      2023-11-28 [1] CRAN (R 4.4.1)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.4.1)
#>  styler         1.10.3     2024-04-07 [1] CRAN (R 4.4.1)
#>  survival     * 3.7-0      2024-06-05 [1] CRAN (R 4.4.1)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.4.1)
#>  tidymodels   * 1.2.0      2024-03-25 [1] CRAN (R 4.4.1)
#>  tidyr        * 1.3.1      2024-01-24 [1] CRAN (R 4.4.1)
#>  tidyselect     1.2.1      2024-03-11 [1] CRAN (R 4.4.1)
#>  timechange     0.3.0      2024-01-18 [1] CRAN (R 4.4.1)
#>  timeDate       4032.109   2023-12-14 [1] CRAN (R 4.4.1)
#>  tune         * 1.2.1      2024-04-18 [1] CRAN (R 4.4.1)
#>  utf8           1.2.4      2023-10-22 [1] CRAN (R 4.4.1)
#>  vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.4.1)
#>  withr          3.0.0      2024-01-16 [1] CRAN (R 4.4.1)
#>  workflows    * 1.1.4      2024-02-19 [1] CRAN (R 4.4.1)
#>  workflowsets * 1.1.0      2024-03-21 [1] CRAN (R 4.4.1)
#>  xfun           0.45       2024-06-16 [1] CRAN (R 4.4.1)
#>  xgboost        1.7.7.1    2024-01-25 [1] CRAN (R 4.4.1)
#>  yaml           2.3.8      2023-12-11 [1] CRAN (R 4.4.1)
#>  yardstick    * 1.3.1      2024-03-21 [1] CRAN (R 4.4.1)
#> 
#>  [1] C:/Users/bruno/AppData/Local/R/win-library/4.4
#>  [2] C:/Program Files/R/R-4.4.1/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────
