From 4698dcdb5238748951a087a5b26309c6b2826cc0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 27 Nov 2022 20:28:36 +0200 Subject: [PATCH] whisper : add mechanism for aborting the whisper_full() computation --- examples/main/main.cpp | 13 +++++++++++++ whisper.cpp | 13 +++++++++++++ whisper.h | 11 +++++++++++ 3 files changed, 37 insertions(+) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 569404caa49..465d43fb079 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -607,6 +607,19 @@ int main(int argc, char ** argv) { wparams.new_segment_callback_user_data = &user_data; } + // example for abort mechanism + // in this example, we do not abort the processing, but we could if the flag is set to true + // the callback is called before every encoder run - if it returns false, the processing is aborted + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) { + bool is_aborted = *(bool*)user_data; + return !is_aborted; + }; + wparams.encoder_begin_callback_user_data = &is_aborted; + } + if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 10; diff --git a/whisper.cpp b/whisper.cpp index 2daf41165d7..fbcb5d14c03 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2451,6 +2451,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.new_segment_callback =*/ nullptr, /*.new_segment_callback_user_data =*/ nullptr, + + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, }; } break; case WHISPER_SAMPLING_BEAM_SEARCH: @@ -2497,6 +2500,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.new_segment_callback =*/ nullptr, /*.new_segment_callback_user_data =*/ nullptr, + + /*.encoder_begin_callback =*/ nullptr, 
+ /*.encoder_begin_callback_user_data =*/ nullptr, }; } break; } @@ -2659,6 +2665,13 @@ int whisper_full( break; } + if (params.encoder_begin_callback) { + if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) { + fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__); + break; + } + } + // encode audio features starting at offset seek if (whisper_encode(ctx, seek, params.n_threads) != 0) { fprintf(stderr, "%s: failed to encode\n", __func__); diff --git a/whisper.h b/whisper.h index 4b5fbccd4e3..156edbbf454 100644 --- a/whisper.h +++ b/whisper.h @@ -185,6 +185,14 @@ extern "C" { // Use the whisper_full_...() functions to obtain the text segments typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); + // Encoder begin callback + // If not NULL, called before the encoder starts + // If it returns false, the computation is aborted + typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data); + + // Parameters for the whisper_full() function + // If you change the order or add new parameters, make sure to update the default values in whisper.cpp: + // whisper_full_default_params() struct whisper_full_params { enum whisper_sampling_strategy strategy; @@ -231,6 +239,9 @@ extern "C" { whisper_new_segment_callback new_segment_callback; void * new_segment_callback_user_data; + + whisper_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; }; WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);