Skip to content

Commit

Permalink
fix: replace tf.keras imports with keras, update tiny rnnt model result
Browse files Browse the repository at this point in the history
  • Loading branch information
nglehuy committed May 25, 2024
1 parent a4d411d commit 63a400c
Show file tree
Hide file tree
Showing 54 changed files with 1,013 additions and 603 deletions.
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,12 @@ repos:
stages: [pre-commit]
fail_fast: true
verbose: true
- id: pylint-check
name: pylint-check
entry: pylint --rcfile=.pylintrc -rn -sn
language: system
types: [python]
stages: [pre-commit]
fail_fast: true
require_serial: true
verbose: true
6 changes: 5 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ disable=too-few-public-methods,
consider-using-enumerate,
too-many-statements,
assignment-from-none,
eval-used
eval-used,
duplicate-code,
redefined-outer-name,
consider-using-f-string,
fixme,

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
Expand Down
14 changes: 12 additions & 2 deletions examples/inferences/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
import os

import tensorflow as tf
import keras

from tensorflow_asr import schemas, tokenizers
from tensorflow_asr.models import base_model
from tensorflow_asr.configs import Config
from tensorflow_asr.utils import cli_util, data_util, env_util, file_util

Expand All @@ -35,7 +37,7 @@ def main(
config = Config(config_path, training=False, repodir=repodir)
tokenizer = tokenizers.get(config)

model: tf.keras.Model = tf.keras.models.model_from_config(config.model_config)
model: base_model.BaseModel = keras.models.model_from_config(config.model_config)
model.make(batch_size=1)
model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5), skip_mismatch=False)
model.summary()
Expand All @@ -44,7 +46,15 @@ def main(
signal = tf.reshape(signal, [1, -1])
signal_length = tf.reshape(tf.shape(signal)[1], [1])

outputs = model.recognize(schemas.PredictInput(signal, signal_length))
outputs = model.recognize(
schemas.PredictInput(
inputs=signal,
inputs_length=signal_length,
previous_tokens=model.get_initial_tokens(),
previous_encoder_states=model.get_initial_encoder_states(),
previous_decoder_states=model.get_initial_decoder_states(),
)
)
print(outputs.tokens)
transcript = tokenizer.detokenize(outputs.tokens)[0].numpy().decode("utf-8")

Expand Down
178 changes: 89 additions & 89 deletions examples/inferences/rnn_transducer.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,89 @@
# Copyright 2020 Huy Le Nguyen (@nglehuy)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example script: non-streaming RNN Transducer inference on a single audio file.
#
# Flow: parse CLI arguments -> select devices -> load config and featurizers ->
# build the model and load weights -> read the audio file -> decode (beam search,
# timestamped greedy, or plain greedy depending on the flags) -> log the result.

import argparse

from tensorflow_asr.utils import data_util, env_util, math_util

# NOTE(review): setup_environment() is deliberately called before TensorFlow is
# imported (ordering kept from the original); presumably it configures
# environment settings that must be in place first -- confirm in env_util.
logger = env_util.setup_environment()
import tensorflow as tf

parser = argparse.ArgumentParser(prog="Rnn Transducer non streaming")

parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back")

parser.add_argument("--config", type=str, default=None, help="Path to rnnt config yaml")

parser.add_argument("--saved", type=str, default=None, help="Path to rnnt saved h5 weights")

parser.add_argument("--beam_width", type=int, default=0, help="Beam width")

parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp")

parser.add_argument("--device", type=int, default=0, help="Device's id to run test on")

parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")

# This is a boolean switch (store_true), not a path, so the help text says so.
parser.add_argument("--subwords", default=False, action="store_true", help="Whether to use subword tokenizer")

parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")

args = parser.parse_args()

# Devices are selected before the model-related modules below are imported
# (ordering kept from the original).
env_util.setup_devices([args.device], cpu=args.cpu)

from tensorflow_asr.configs import Config
from tensorflow_asr.features.speech_featurizers import SpeechFeaturizer, read_raw_audio
from tensorflow_asr.models.transducer.rnnt import RnnTransducer
from tensorflow_asr.tokenizers import CharTokenizer, SentencePieceTokenizer, SubwordFeaturizer

config = Config(args.config)
speech_featurizer = SpeechFeaturizer(config.speech_config)

# Pick the text featurizer from the CLI flags; character-level is the fallback.
if args.sentence_piece:
    logger.info("Loading SentencePiece model ...")
    text_featurizer = SentencePieceTokenizer(config.decoder_config)
elif args.subwords:
    logger.info("Loading subwords ...")
    text_featurizer = SubwordFeaturizer(config.decoder_config)
else:
    text_featurizer = CharTokenizer(config.decoder_config)
# The beam width from the CLI overrides the configured one for every featurizer kind.
text_featurizer.decoder_config.beam_width = args.beam_width

# build model
rnnt = RnnTransducer(**config.model_config, vocab_size=text_featurizer.num_classes)
rnnt.make(speech_featurizer.shape)
rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True)
rnnt.summary()
rnnt.add_featurizers(speech_featurizer, text_featurizer)

signal = read_raw_audio(args.filename)
features = speech_featurizer.tf_extract(signal)
input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor)

# Log calls use "%s" placeholders: passing the value as a bare extra positional
# argument makes logging-style loggers fail to format the record, so the
# transcript would never be printed.
if args.beam_width:
    # A non-zero beam width selects beam-search decoding.
    transcript = rnnt.recognize_beam(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...]))
    logger.info("Transcript: %s", transcript[0].numpy().decode("UTF-8"))
elif args.timestamp:
    # Timestamped decoding works on the raw signal and explicit initial states.
    transcript, stime, etime, _, _, _ = rnnt.recognize_tflite_with_timestamp(
        signal=signal,
        predicted=tf.constant(text_featurizer.blank, dtype=tf.int32),
        encoder_states=rnnt.encoder.get_initial_state(),
        prediction_states=rnnt.predict_net.get_initial_state(),
    )
    logger.info("Transcript: %s", transcript)
    logger.info("Start time: %s", stime)
    logger.info("End time: %s", etime)
else:
    # Default decoding path.
    transcript = rnnt.recognize(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...]))
    logger.info("Transcript: %s", transcript[0].numpy().decode("UTF-8"))
# # Copyright 2020 Huy Le Nguyen (@nglehuy)
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# # http://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.

# import argparse

# from tensorflow_asr.utils import data_util, env_util, math_util

# logger = env_util.setup_environment()
# import tensorflow as tf

# parser = argparse.ArgumentParser(prog="Rnn Transducer non streaming")

# parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back")

# parser.add_argument("--config", type=str, default=None, help="Path to rnnt config yaml")

# parser.add_argument("--saved", type=str, default=None, help="Path to rnnt saved h5 weights")

# parser.add_argument("--beam_width", type=int, default=0, help="Beam width")

# parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp")

# parser.add_argument("--device", type=int, default=0, help="Device's id to run test on")

# parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")

# parser.add_argument("--subwords", default=False, action="store_true", help="Path to file that stores generated subwords")

# parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")

# args = parser.parse_args()

# env_util.setup_devices([args.device], cpu=args.cpu)

# from tensorflow_asr.configs import Config
# from tensorflow_asr.features.speech_featurizers import SpeechFeaturizer, read_raw_audio
# from tensorflow_asr.models.transducer.rnnt import RnnTransducer
# from tensorflow_asr.tokenizers import CharTokenizer, SentencePieceTokenizer, SubwordFeaturizer

# config = Config(args.config)
# speech_featurizer = SpeechFeaturizer(config.speech_config)
# if args.sentence_piece:
# logger.info("Loading SentencePiece model ...")
# text_featurizer = SentencePieceTokenizer(config.decoder_config)
# elif args.subwords:
# logger.info("Loading subwords ...")
# text_featurizer = SubwordFeaturizer(config.decoder_config)
# else:
# text_featurizer = CharTokenizer(config.decoder_config)
# text_featurizer.decoder_config.beam_width = args.beam_width

# # build model
# rnnt = RnnTransducer(**config.model_config, vocab_size=text_featurizer.num_classes)
# rnnt.make(speech_featurizer.shape)
# rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True)
# rnnt.summary()
# rnnt.add_featurizers(speech_featurizer, text_featurizer)

# signal = read_raw_audio(args.filename)
# features = speech_featurizer.tf_extract(signal)
# input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor)

# if args.beam_width:
# transcript = rnnt.recognize_beam(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...]))
# logger.info("Transcript:", transcript[0].numpy().decode("UTF-8"))
# elif args.timestamp:
# transcript, stime, etime, _, _, _ = rnnt.recognize_tflite_with_timestamp(
# signal=signal,
# predicted=tf.constant(text_featurizer.blank, dtype=tf.int32),
# encoder_states=rnnt.encoder.get_initial_state(),
# prediction_states=rnnt.predict_net.get_initial_state(),
# )
# logger.info("Transcript:", transcript)
# logger.info("Start time:", stime)
# logger.info("End time:", etime)
# else:
# transcript = rnnt.recognize(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...]))
# logger.info("Transcript:", transcript[0].numpy().decode("UTF-8"))
Loading

0 comments on commit 63a400c

Please sign in to comment.