From 0f619b52ce3e7652fa78b792ee5b23584210562a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 25 Nov 2022 22:08:58 +0200 Subject: [PATCH] main : add stereo-channel-based diarization (#64) Not tested - I don't have stereo dialog audio --- examples/main/main.cpp | 81 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 71 insertions(+), 10 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 2b9f2e1745d..569404caa49 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -36,6 +36,10 @@ std::string to_timestamp(int64_t t, bool comma = false) { return std::string(buf); } +int timestamp_to_sample(int64_t t, int n_samples) { + return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); +} + // helper function to replace substrings void replace_all(std::string & s, const std::string & search, const std::string & replace) { for (size_t pos = 0; ; pos += replace.length()) { @@ -60,6 +64,7 @@ struct whisper_params { bool speed_up = false; bool translate = false; + bool diarize = false; bool output_txt = false; bool output_vtt = false; bool output_srt = false; @@ -99,6 +104,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } @@ -135,6 +141,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", 
params.word_thold); fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); @@ -148,8 +155,15 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, "\n"); } +struct whisper_print_user_data { + const whisper_params * params; + + const std::vector<std::vector<float>> * pcmf32s; +}; + void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { - const whisper_params & params = *(whisper_params *) user_data; + const auto & params = *((whisper_print_user_data *) user_data)->params; + const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; const int n_segments = whisper_full_n_segments(ctx); @@ -186,6 +200,33 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) { + const int64_t n_samples = pcmf32s[0].size(); + + const int64_t is0 = timestamp_to_sample(t0, n_samples); + const int64_t is1 = timestamp_to_sample(t1, n_samples); + + double energy0 = 0.0f; + double energy1 = 0.0f; + + for (int64_t j = is0; j < is1; j++) { + energy0 += fabs(pcmf32s[0][j]); + energy1 += fabs(pcmf32s[1][j]); + } + + if (energy0 > 1.1*energy1) { + speaker = "(speaker 
0)"; + } else if (energy1 > 1.1*energy0) { + speaker = "(speaker 1)"; + } else { + speaker = "(speaker ?)"; + } + + //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str()); + } + if (params.print_colors) { printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { @@ -201,13 +242,13 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); - printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); } printf("\n"); } else { const char * text = whisper_full_get_segment_text(ctx, i); - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text); } } } @@ -235,7 +276,7 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) { std::ofstream fout(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); - return 9; + return false; } fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); @@ -425,6 +466,7 @@ int main(int argc, char ** argv) { const auto fname_inp = params.fname_inp[f]; std::vector<float> pcmf32; // mono-channel F32 PCM + std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM // WAV input { @@ -453,22 +495,27 @@ int main(int argc, char ** argv) { } else if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) { fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); - return 4; + return 5; } if (wav.channels != 1 && wav.channels != 2) { fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str()); - return 5; + return 6; + } + + if (params.diarize && wav.channels != 2 && 
params.no_timestamps == false) { + fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", argv[0], fname_inp.c_str()); + return 6; } if (wav.sampleRate != WHISPER_SAMPLE_RATE) { fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str()); - return 6; + return 8; } if (wav.bitsPerSample != 16) { fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str()); - return 7; + return 9; } const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8); @@ -489,6 +536,18 @@ int main(int argc, char ** argv) { pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; } } + + if (params.diarize) { + // convert to stereo, float + pcmf32s.resize(2); + + pcmf32s[0].resize(n); + pcmf32s[1].resize(n); + for (int i = 0; i < n; i++) { + pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; + pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; + } + } } // print system information @@ -540,15 +599,17 @@ int main(int argc, char ** argv) { wparams.speed_up = params.speed_up; + whisper_print_user_data user_data = { &params, &pcmf32s }; + // this callback is called on each new segment if (!wparams.print_realtime) { wparams.new_segment_callback = whisper_print_segment_callback; - wparams.new_segment_callback_user_data = &params; + wparams.new_segment_callback_user_data = &user_data; } if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); - return 8; + return 10; } }