The Whisper model, originally developed in PyTorch, has also been ported to TensorFlow. One notable example is Hugging Face’s TFWhisperForConditionalGeneration, which subclasses TFPreTrainedModel (itself a tf.keras.Model subclass). Let’s walk through the short piece of code that generates the final .tflite file, ready to be integrated into an Android application.
The conversion is done in Google Colaboratory (Colab), which lets anyone write and execute arbitrary Python code in the browser and is especially well suited to machine learning, data analysis and education. More technically, Colab is a hosted Jupyter notebook service that requires no setup and provides free access to computing resources, including GPUs.
First, we install the datasets library and TensorFlow 2.14.0 (at the time of writing, this version produces a successful conversion and a .tflite file that works with the TensorFlow Lite Interpreter):
```
!pip install datasets
!pip install tensorflow==2.14.0
```
Then we import the libraries, load the model, run inference on a sample and save the model in the SavedModel format:
```python
import tensorflow as tf
import transformers
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration, WhisperTokenizer

# Load the feature extractor, tokenizer and TF model for the English-only tiny checkpoint
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny.en")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en", predict_timestamps=True)
processor = WhisperProcessor(feature_extractor, tokenizer)
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# Load a small sample dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = feature_extractor(
    ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
)
input_features = inputs.input_features  # log-mel spectrogram, shape (1, 80, 3000)

# Generate the transcription
generated_ids = model.generate(input_features=input_features)
print(generated_ids)
transcription = processor.tokenizer.decode(generated_ids[0])
print(transcription)

# Save the model
model.save('/content/tf_whisper_saved')
```
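The fixed input shape (1, 80, 3000) used in the export step comes from how Whisper prepares audio. The sketch below reproduces that arithmetic, assuming the WhisperFeatureExtractor defaults for whisper-tiny.en (16 kHz audio padded or trimmed to 30 seconds, a hop length of 160 samples and 80 mel bins); the constant names and the helper `whisper_input_shape` are illustrative, not part of the transformers API.

```python
# Assumed WhisperFeatureExtractor defaults for openai/whisper-tiny.en
SAMPLING_RATE = 16_000  # Hz, the sample rate Whisper expects
CHUNK_SECONDS = 30      # every input is padded/trimmed to 30 s of audio
HOP_LENGTH = 160        # samples between consecutive spectrogram frames
N_MELS = 80             # number of mel frequency bins


def whisper_input_shape(batch_size=1):
    """Hypothetical helper: shape of input_features for a batch of clips."""
    n_samples = SAMPLING_RATE * CHUNK_SECONDS  # 480,000 samples per clip
    n_frames = n_samples // HOP_LENGTH         # 3,000 frames per clip
    return (batch_size, N_MELS, n_frames)


print(whisper_input_shape())  # (1, 80, 3000)
```

Because every clip is padded to the same length, the exported signature can use a fully static shape, which keeps the TFLite graph simple.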
Next, we wrap the model in a tf.Module whose serving function calls generate(). This gives the exported model a well-defined input (input_features) and output (sequences):
```python
class GenerateModel(tf.Module):
    def __init__(self, model):
        super(GenerateModel, self).__init__()
        self.model = model

    @tf.function(
        # Static input: batch of 1, 80 mel bins, 3000 frames (30 s of audio)
        input_signature=[
            tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features"),
        ],
    )
    def serving(self, input_features):
        outputs = self.model.generate(
            input_features,
            # Maximum number of tokens to generate; increase it
            # (e.g. to 200) if you expect longer transcriptions
            max_new_tokens=100,
            return_dict_in_generate=True,
        )
        return {"sequences": outputs["sequences"]}


saved_model_dir = '/content/tf_whisper_saved'
tflite_model_path = 'whisper_english.tflite'

# Export with generate() as the "serving_default" signature
generate_model = GenerateModel(model=model)
tf.saved_model.save(generate_model, saved_model_dir, signatures={"serving_default": generate_model.serving})
```
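To see what this export pattern produces without downloading Whisper, here is a minimal sketch using a hypothetical `ToyServingModule`. It keeps the same fixed `input_features` signature and the `{"sequences": ...}` output dict, but replaces `generate()` with a trivial reduction (an assumption made purely for illustration). Reloading the SavedModel shows that the TensorSpec name and the dict key become the signature's input and output names.

```python
import tempfile

import tensorflow as tf


class ToyServingModule(tf.Module):
    """Hypothetical stand-in for GenerateModel: same signature, no real model."""

    @tf.function(
        input_signature=[tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features")]
    )
    def serving(self, input_features):
        # Stand-in for model.generate(): average over mel bins, cast to int32 "ids"
        ids = tf.cast(tf.reduce_mean(input_features, axis=1), tf.int32)
        return {"sequences": ids}


module = ToyServingModule()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir, signatures={"serving_default": module.serving})

# Reload and inspect the exported signature
loaded = tf.saved_model.load(export_dir)
serving_fn = loaded.signatures["serving_default"]
print(serving_fn.structured_input_signature)  # named input: input_features
print(serving_fn.structured_outputs)          # named output: sequences

# Signature functions are called with keyword arguments
out = serving_fn(input_features=tf.zeros((1, 80, 3000)))
print(out["sequences"].shape)  # (1, 3000)
```

The same inspection (`tf.saved_model.load` plus `.signatures["serving_default"]`) can be pointed at the real `saved_model_dir` to confirm the Whisper export before converting it.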
```python
# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops (the Android app then needs the Select TF ops library).
]
# Learn about post-training quantization:
# https://www.tensorflow.org/lite/performance/post_training_quantization
# Dynamic range quantization stores weights as 8-bit integers,
# reducing the model to roughly 25% of its float32 size
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Float16 quantization reduces it to roughly 50% instead; to use it,
# uncomment the line below and keep Optimize.DEFAULT set as well
# converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()

# Save the model
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)
```
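The 25% and 50% figures in the comments above follow from bytes per weight: float32 uses 4 bytes, float16 uses 2 and int8 uses 1. The sketch below applies that arithmetic, assuming roughly 39 million parameters for Whisper tiny; it counts weights only, so a real .tflite file will differ somewhat (graph structure, metadata, and any tensors the converter leaves in float).

```python
# Bytes used to store one weight in each format
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1}


def approx_weight_size_mb(n_params, dtype):
    """Rough weight-storage size in MB (10**6 bytes), ignoring all overhead."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e6


N_PARAMS = 39_000_000  # assumption: approximate parameter count of Whisper tiny

for dtype in ("float32", "float16", "int8"):
    print(dtype, round(approx_weight_size_mb(N_PARAMS, dtype)), "MB")
# float32 156 MB
# float16 78 MB
# int8 39 MB
```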
You can inspect the model’s structure and its inputs/outputs with Netron (netron.app).
Finally, we use the TensorFlow Lite Interpreter to check the transcription:
```python
# Load the converted model; its serving signature wraps generate()
tflite_model_path = 'whisper_english.tflite'
interpreter = tf.lite.Interpreter(tflite_model_path)
tflite_generate = interpreter.get_signature_runner()

# Run generation on the same input features used earlier
generated_ids = tflite_generate(input_features=input_features)["sequences"]
print(generated_ids)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(transcription)
```
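The whole export, convert and run loop can also be exercised on a tiny model, which is a quick way to check a TensorFlow installation before spending time on the full Whisper conversion. This is a minimal sketch under stated assumptions: `ToyServingModule` is hypothetical, its small (1, 4, 6) input and `reduce_sum` body stand in for Whisper and `generate()`, and the converter is built in memory rather than written to disk.

```python
import tempfile

import numpy as np
import tensorflow as tf


class ToyServingModule(tf.Module):
    """Hypothetical module mirroring the GenerateModel signature pattern."""

    @tf.function(input_signature=[tf.TensorSpec((1, 4, 6), tf.float32, name="input_features")])
    def serving(self, input_features):
        # Stand-in for generate(): sum over the middle axis, cast to int32 "ids"
        ids = tf.cast(tf.reduce_sum(input_features, axis=1), tf.int32)
        return {"sequences": ids}


# 1. Export to a SavedModel with an explicit serving signature
module = ToyServingModule()
saved_dir = tempfile.mkdtemp()
tf.saved_model.save(module, saved_dir, signatures={"serving_default": module.serving})

# 2. Convert with the same op sets as the Whisper conversion
converter = tf.lite.TFLiteConverter.from_saved_model(saved_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_bytes = converter.convert()

# 3. Run through the signature runner, as in the Whisper check
interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
signatures = interpreter.get_signature_list()
print(signatures)  # lists the input and output names of serving_default
runner = interpreter.get_signature_runner("serving_default")
result = runner(input_features=np.ones((1, 4, 6), dtype=np.float32))
print(result["sequences"])  # [[4 4 4 4 4 4]]
```

The signature names printed here are the same names an Android app would use when calling the model through the TensorFlow Lite signature API.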
In the end, we have a fully functional .tflite file that can be used inside an Android application.
The above work has been accelerated by the awesome work that has been done at the below repos:
Conclusion
This effort aims to improve on-device Speech-To-Text (STT). Models that run locally offer better privacy, greater reliability (voice data can be processed without internet connectivity), lower latency (faster, more responsive results) and lower cost (no API usage). As demand for voice assistants and on-device transcription grows, offline on-device STT is poised to play an increasingly important role in speech recognition.