Skip to content

Commit

Permalink
first working version
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Dec 17, 2024
1 parent 4666f37 commit 6508ecd
Show file tree
Hide file tree
Showing 13 changed files with 222 additions and 18 deletions.
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ list(APPEND sources

if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND sources
hifigan-vocoder.cc
jieba-lexicon.cc
lexicon.cc
melo-tts-lexicon.cc
Expand Down
107 changes: 107 additions & 0 deletions sherpa-onnx/csrc/hifigan-vocoder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// sherpa-onnx/csrc/hifigan-vocoder.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/hifigan-vocoder.h"

#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"

namespace sherpa_onnx {

// Private implementation of HifiganVocoder.
//
// It owns the ONNX Runtime environment, the session options and the session
// itself, and caches the model's input/output names so that Run() can hand
// them straight to Ort::Session::Run().
class HifiganVocoder::Impl {
 public:
  // Reads the model from the path `model` and creates the session from the
  // bytes that were read.
  explicit Impl(int32_t num_threads, const std::string &provider,
                const std::string &model)
      : env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(num_threads, provider)),
        allocator_{} {
    auto model_bytes = ReadFile(model);
    Init(model_bytes.data(), model_bytes.size());
  }

  // Same as above, except that the model is read through `mgr`
  // (see the explicit instantiations of the outer class for the manager types
  // that are used).
  template <typename Manager>
  explicit Impl(Manager *mgr, int32_t num_threads, const std::string &provider,
                const std::string &model)
      : env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(num_threads, provider)),
        allocator_{} {
    auto model_bytes = ReadFile(mgr, model);
    Init(model_bytes.data(), model_bytes.size());
  }

  // Feeds `mel` as the single input of the session and returns the first
  // output of the session.
  Ort::Value Run(Ort::Value mel) const {
    const char *const *in_names = input_names_ptr_.data();
    const char *const *out_names = output_names_ptr_.data();
    const size_t num_outputs = output_names_ptr_.size();

    std::vector<Ort::Value> results =
        sess_->Run({}, in_names, &mel, 1, out_names, num_outputs);

    // Only the first output is handed back to the caller.
    return std::move(results.front());
  }

 private:
  // Creates the session from an in-memory model and caches the names of its
  // inputs and outputs.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    Ort::Session *session = sess_.get();
    GetInputNames(session, &input_names_, &input_names_ptr_);
    GetOutputNames(session, &output_names_, &output_names_ptr_);
  }

  // NOTE: the declaration order below is also the initialization order used
  // by the constructors.
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  // Input names and the matching C-string pointers passed to the session.
  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  // Output names and the matching C-string pointers passed to the session.
  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;
};

// Creates a vocoder whose model is read from the file path `model`.
// All arguments are forwarded unchanged to the Impl constructor, which uses
// `num_threads` and `provider` to build the session options.
HifiganVocoder::HifiganVocoder(int32_t num_threads, const std::string &provider,
const std::string &model)
: impl_(std::make_unique<Impl>(num_threads, provider, model)) {}

// Creates a vocoder whose model is read through the resource manager `mgr`.
// The template is defined in this .cc file only, so it is usable solely for
// the Manager types that are explicitly instantiated further below.
template <typename Manager>
HifiganVocoder::HifiganVocoder(Manager *mgr, int32_t num_threads,
const std::string &provider,
const std::string &model)
: impl_(std::make_unique<Impl>(mgr, num_threads, provider, model)) {}

// Defaulted here, after Impl has been fully defined, because the header only
// forward-declares Impl and std::unique_ptr<Impl> needs the complete type to
// destroy it.
HifiganVocoder::~HifiganVocoder() = default;

// Forwards `mel` to the implementation and returns the tensor it produces.
// Ownership of `mel` is moved into the call.
Ort::Value HifiganVocoder::Run(Ort::Value mel) const {
return impl_->Run(std::move(mel));
}

// Explicit instantiation of the templated constructor for Android's
// AAssetManager. It is required because the template is defined in this
// translation unit and not in the header.
#if __ANDROID_API__ >= 9
template HifiganVocoder::HifiganVocoder(AAssetManager *mgr, int32_t num_threads,
const std::string &provider,
const std::string &model);
#endif

// Explicit instantiation of the templated constructor for the OHOS
// NativeResourceManager. It is required because the template is defined in
// this translation unit and not in the header.
#if __OHOS__
template HifiganVocoder::HifiganVocoder(NativeResourceManager *mgr,
int32_t num_threads,
const std::string &provider,
const std::string &model);
#endif

} // namespace sherpa_onnx
38 changes: 38 additions & 0 deletions sherpa-onnx/csrc/hifigan-vocoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// sherpa-onnx/csrc/hifigan-vocoder.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_
#define SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_

#include <memory>
#include <string>

#include "onnxruntime_cxx_api.h" // NOLINT

namespace sherpa_onnx {

// Vocoder that maps a mel tensor to an audio tensor using an ONNX model.
//
// The implementation is hidden behind the forward-declared Impl class, which
// is defined in hifigan-vocoder.cc.
class HifiganVocoder {
public:
// Defined out of line in the .cc file, where Impl is a complete type.
~HifiganVocoder();

/** Create a vocoder from a model file.
 *
 * @param num_threads Used together with `provider` to build the ONNX Runtime
 *                    session options.
 * @param provider    See `num_threads`.
 * @param model       Path of the vocoder model file to read.
 */
HifiganVocoder(int32_t num_threads, const std::string &provider,
const std::string &model);

/** Same as the constructor above, except that the model is read through
 * the resource manager `mgr`.
 *
 * This template is only declared here; it is defined and explicitly
 * instantiated in the .cc file, so it can only be used with the Manager
 * types instantiated there.
 */
template <typename Manager>
HifiganVocoder(Manager *mgr, int32_t num_threads, const std::string &provider,
const std::string &model);

/** Run the vocoder.
 *
 * @param mel A float32 tensor of shape (batch_size, feat_dim, num_frames).
 * @return Return a float32 tensor of shape (batch_size, num_samples).
 */
Ort::Value Run(Ort::Value mel) const;

private:
// Forward declaration only; the definition lives in the .cc file.
class Impl;
std::unique_ptr<Impl> impl_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/jieba-lexicon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class JiebaLexicon::Impl {
}

this_sentence.insert(this_sentence.end(), ids.begin(), ids.end());
this_sentence.push_back(blank);
// this_sentence.push_back(blank);

if (w == "" || w == "" || w == "" || w == "") {
ans.emplace_back(std::move(this_sentence));
Expand Down
6 changes: 3 additions & 3 deletions sherpa-onnx/csrc/offline-tts-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

namespace sherpa_onnx {

std::vector<int64_t> OfflineTtsImpl::AddBlank(
const std::vector<int64_t> &x) const {
std::vector<int64_t> OfflineTtsImpl::AddBlank(const std::vector<int64_t> &x,
int32_t blank_id /*= 0*/) const {
// The blank ID is given by blank_id, which defaults to 0
std::vector<int64_t> buffer(x.size() * 2 + 1);
std::vector<int64_t> buffer(x.size() * 2 + 1, blank_id);
int32_t i = 1;
for (auto k : x) {
buffer[i] = k;
Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-tts-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class OfflineTtsImpl {
// If it supports only a single speaker, then it return 0 or 1.
virtual int32_t NumSpeakers() const = 0;

std::vector<int64_t> AddBlank(const std::vector<int64_t> &x) const;
std::vector<int64_t> AddBlank(const std::vector<int64_t> &x,
int32_t blank_id = 0) const;
};

} // namespace sherpa_onnx
Expand Down
67 changes: 58 additions & 9 deletions sherpa-onnx/csrc/offline-tts-matcha-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "fst/extensions/far/far.h"
#include "kaldifst/csrc/kaldi-fst-io.h"
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/hifigan-vocoder.h"
#include "sherpa-onnx/csrc/jieba-lexicon.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
Expand All @@ -31,7 +32,10 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
public:
explicit OfflineTtsMatchaImpl(const OfflineTtsConfig &config)
: config_(config),
model_(std::make_unique<OfflineTtsMatchaModel>(config.model)) {
model_(std::make_unique<OfflineTtsMatchaModel>(config.model)),
vocoder_(std::make_unique<HifiganVocoder>(
config.model.num_threads, config.model.provider,
config.model.matcha.vocoder)) {
InitFrontend();

if (!config.rule_fsts.empty()) {
Expand Down Expand Up @@ -87,7 +91,10 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
template <typename Manager>
OfflineTtsMatchaImpl(Manager *mgr, const OfflineTtsConfig &config)
: config_(config),
model_(std::make_unique<OfflineTtsMatchaModel>(mgr, config.model)) {
model_(std::make_unique<OfflineTtsMatchaModel>(mgr, config.model)),
vocoder_(std::make_unique<HifiganVocoder>(
mgr, config.model.num_threads, config.model.provider,
config.model.matcha.vocoder)) {
InitFrontend(mgr);

if (!config.rule_fsts.empty()) {
Expand Down Expand Up @@ -228,10 +235,54 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
x.push_back(std::move(i.tokens));
}

if (config_.model.debug) {
std::ostringstream os;
os << "\n";
for (const auto &k : x) {
for (int32_t i : k) {
os << i << " ";
}
os << "\n";
}
os << "\n";
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}

if (meta_data.add_blank) {
for (auto &k : x) {
k = AddBlank(k);
}

if (config_.model.debug) {
std::ostringstream os;
os << "\n";
for (const auto &k : x) {
for (int32_t i : k) {
os << i << " ";
}
os << "\n";
}
os << "\n";
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}

for (auto &k : x) {
// TODO(fangjun): Fix it! The blank ID 62 is hard-coded here.
k = AddBlank(k, 62);
}
}

if (config_.model.debug) {
std::ostringstream os;
os << "\n";
for (const auto &k : x) {
for (int32_t i : k) {
os << i << " ";
}
os << "\n";
}
os << "\n";
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}

int32_t x_size = static_cast<int32_t>(x.size());
Expand Down Expand Up @@ -345,21 +396,18 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());

Ort::Value mel = model_->Run(std::move(x_tensor), sid, speed);
Ort::Value audio = vocoder_->Run(std::move(mel));

std::vector<int64_t> mel_shape = mel.GetTensorTypeAndShapeInfo().GetShape();
SHERPA_ONNX_LOGE("mel shape size: %d", (int)mel_shape.size());
for (int32_t i : mel_shape) {
SHERPA_ONNX_LOGE(" %d", i);
}
return {};
std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();

int64_t total = 1;
// The output shape may be (1, 1, total) or (1, total) or (total,)
for (auto i : audio_shape) {
total *= i;
}

const float *p = mel.GetTensorData<float>();
const float *p = audio.GetTensorData<float>();

GeneratedAudio ans;
ans.sample_rate = model_->GetMetaData().sample_rate;
Expand All @@ -370,6 +418,7 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
private:
OfflineTtsConfig config_;
std::unique_ptr<OfflineTtsMatchaModel> model_;
std::unique_ptr<HifiganVocoder> vocoder_;
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> tn_list_;
std::unique_ptr<OfflineTtsFrontend> frontend_;
};
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-tts-matcha-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ const OfflineTtsMatchaModelMetaData &OfflineTtsMatchaModel::GetMetaData()
}

Ort::Value OfflineTtsMatchaModel::Run(Ort::Value x, int64_t sid /*= 0*/,
float speed /*= 1.0*/) {
float speed /*= 1.0*/) const {
return impl_->Run(std::move(x), sid, speed);
}

Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-tts-matcha-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class OfflineTtsMatchaModel {

// Return a float32 tensor containing the mel
// of shape (batch_size, mel_dim, num_frames)
Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0);
Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0) const;

const OfflineTtsMatchaModelMetaData &GetMetaData() const;

Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-tts-vits-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/,

Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, Ort::Value tones,
int64_t sid /*= 0*/,
float speed /*= 1.0*/) {
float speed /*= 1.0*/) const {
return impl_->Run(std::move(x), std::move(tones), sid, speed);
}

Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-tts-vits-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class OfflineTtsVitsModel {

// This is for MeloTTS
Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid = 0,
float speed = 1.0);
float speed = 1.0) const;

const OfflineTtsVitsModelMetaData &GetMetaData() const;

Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,4 +273,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) {
return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
}

// Builds session options directly from a thread count and a provider name,
// for callers that do not have a config struct to pass. It simply forwards
// both arguments to GetSessionOptionsImpl().
Ort::SessionOptions GetSessionOptions(int32_t num_threads,
const std::string &provider_str);

} // namespace sherpa_onnx
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
const std::string &model_type);

// Overload that takes the number of threads and the provider name directly
// instead of reading them from a config struct.
Ort::SessionOptions GetSessionOptions(int32_t num_threads,
const std::string &provider_str);

template <typename T>
Ort::SessionOptions GetSessionOptions(const T &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
Expand Down

0 comments on commit 6508ecd

Please sign in to comment.