From 624d44d7f56571326caa61cbc37c74e0575e11ce Mon Sep 17 00:00:00 2001 From: Matt Dancho Date: Mon, 17 Oct 2022 13:40:56 -0400 Subject: [PATCH] #22 recursive chunks --- R/modeltime_forecast.R | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/R/modeltime_forecast.R b/R/modeltime_forecast.R index 2bc42a8..905bd7a 100644 --- a/R/modeltime_forecast.R +++ b/R/modeltime_forecast.R @@ -402,9 +402,14 @@ mdl_time_forecast_recursive_ensemble <- function(object, calibration_data, f = (seq_len(nrow(new_data)) - 1) %/% chunk_size) # LOOP LOGIC ---- + + # print(new_data) + .first_slice <- new_data %>% dplyr::slice_head(n = chunk_size) + # print(.first_slice) + .forecasts <- modeltime::mdl_time_forecast( object, new_data = .first_slice, @@ -418,8 +423,12 @@ mdl_time_forecast_recursive_ensemble <- function(object, calibration_data, .forecast_from_model <- .forecasts %>% dplyr::filter(.key == "prediction") + # print(.forecast_from_model) + new_data[idx_sets[[1]], y_var] <- .forecast_from_model$.value + # print(new_data) + .temp_new_data <- dplyr::bind_rows( train_tail, new_data @@ -433,6 +442,8 @@ mdl_time_forecast_recursive_ensemble <- function(object, calibration_data, transform_window_start <- min(idx_sets[[i]]) transform_window_end <- max(idx_sets[[i]]) + n_train_tail + # print(.temp_new_data[transform_window_start:transform_window_end,]) + # .nth_slice <- .transform(.temp_new_data, nrow(new_data), i) .nth_slice <- .transform(.temp_new_data[transform_window_start:transform_window_end,], length(idx_sets[[i]])) @@ -448,20 +459,22 @@ mdl_time_forecast_recursive_ensemble <- function(object, calibration_data, ... ) - # print(.nth_forecast) - .nth_forecast_from_model <- .nth_forecast %>% dplyr::filter(.key == "prediction") %>% .[1,] + # print(.nth_forecast_from_model) + .forecasts <- dplyr::bind_rows( .forecasts, .nth_forecast_from_model ) - new_data[idx_sets[[i]], y_var] <- .nth_forecast_from_model$.value + .temp_new_data[idx_sets[[i]] + n_train_tail, y_var] <- .nth_forecast_from_model$.value } } + # print(.forecasts) + return(.forecasts) } @@ -514,6 +527,14 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data, dplyr::mutate(rowid.. = dplyr::row_number()) %>% dplyr::ungroup() + # Fix - When ID is dummied + if (!is.null(object$spec$remove_id)) { + if (object$spec$remove_id) { + .first_slice <- .first_slice %>% + dplyr::select(-(!! .id)) + } + } + if ("rowid.." %in% names(.first_slice)) { .first_slice <- .first_slice %>% dplyr::select(-rowid..) } @@ -532,7 +553,7 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data, .forecast_from_model <- .forecasts %>% dplyr::filter(.key == "prediction") - .preds[.preds$rowid.. %in% idx_sets[[1]], 2] <- new_data[.preds$rowid.. %in% idx_sets[[1]], y_var] <- .forecast_from_model$.value + new_data[.preds$rowid.. %in% idx_sets[[1]], y_var] <- .forecast_from_model$.value .groups <- new_data %>% dplyr::group_by(!! .id) %>% @@ -556,7 +577,6 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data, transform_window_end <- max(idx_sets[[i]]) + n_train_tail - .nth_slice <- .transform(.temp_new_data %>% dplyr::group_by(!! .id) %>% dplyr::slice(transform_window_start:transform_window_end), @@ -594,8 +614,7 @@ mdl_time_forecast_recursive_ensemble_panel <- function(object, calibration_data, .forecasts, .nth_forecast_from_model ) - - .preds[.preds$rowid.. %in% idx_sets[[i]], 2] <- .temp_new_data[.temp_new_data$rowid.. %in% idx_sets[[i]], y_var] <- .nth_forecast_from_model$.value + .temp_new_data[.temp_new_data$rowid.. %in% idx_sets[[i]], y_var] <- .nth_forecast_from_model$.value } }