Following our exploration of quantization and its impact on model efficiency and size, we now delve into another crucial technique for optimizing machine learning models: pruning. This method reduces model size by eliminating weights that contribute little to the model's output accuracy. By making the model sparser, pruning not only decreases the model's footprint but also makes it better suited to edge devices, where storage and computational resources are limited.
Pruning works by zeroing out low-magnitude (insignificant) weights within the network. The rationale behind this technique is that such weights have minimal impact on the predictive power of the model, so removing them won't significantly alter the output. The TensorFlow Model Optimization Toolkit offers a convenient method for implementing pruning, known as prune_low_magnitude, which integrates seamlessly with Keras models.
To apply pruning, the toolkit wraps the model layers with the prune_low_magnitude function, enabling the dynamic zeroing of weights during training. This approach involves defining a pruning schedule, typically using a PolynomialDecay strategy that gradually increases the model's sparsity over the training epochs.
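The code below reuses the baseline_model, train_images, and related variables prepared in the previous article on quantization. If you are starting from scratch, a minimal stand-in along these lines works; this simple MNIST setup is an assumption, not the exact model from that article:
import tensorflow as tf

# Load and normalize MNIST as a stand-in dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0

# A small dense classifier to act as the baseline model
baseline_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
baseline_model.compile(optimizer='adam',
                       loss='sparse_categorical_crossentropy',
                       metrics=['accuracy'])
baseline_model.fit(train_images, train_labels, epochs=2, validation_split=0.1)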
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Get the pruning method from the TensorFlow Model Optimization Toolkit
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
# Set up parameters for training and pruning
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of the training set will be used for validation.
# Calculate the total number of training images considering the validation split
num_images = train_images.shape[0] * (1 - validation_split)
# Calculate the end step to finish pruning after 2 epochs
# np.ceil() rounds up to ensure all images are processed
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
# Define the pruning schedule using Polynomial Decay
# Initial sparsity is 50% which means half of the weights are set to zero initially
# Final sparsity is 80% which means by the end of training, 80% of the weights will be zero
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=end_step
    )
}
# Apply pruning configuration to the baseline model
model_for_pruning = prune_low_magnitude(baseline_model, **pruning_params)
# Recompile the model since the pruning layers require setting up new operations
model_for_pruning.compile(optimizer='adam',
                          loss='sparse_categorical_crossentropy',
                          metrics=['accuracy'])
# Print the model summary to see the changes and the added pruning layers
model_for_pruning.summary()
You can also peek at the weights of one of the layers in your model. At this point they are still dense; once pruning runs during training, many of them will be zeroed out.
# Preview model weights
model_for_pruning.weights[1]
Once the model is set up for pruning, it undergoes re-training, during which the pruning wrappers progressively zero out weights until the target sparsity is reached. The resulting sparse model compresses to a much smaller size while aiming to preserve the original accuracy.
# Callback to update pruning wrappers at each step
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
]

# Train and prune the model
model_for_pruning.fit(train_images, train_labels,
                      epochs=epochs, validation_split=validation_split,
                      callbacks=callbacks)
After the baseline model was wrapped to create model_for_pruning and then re-trained with the pruning callback, many of its weights have been set to zero as a result of the pruning schedule defined earlier. You can confirm this by inspecting the same weight tensor again.
model_for_pruning.weights[1]
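If you prefer a number to an eyeball check, you can count the zero entries directly. The short snippet below (a quick sanity check, not part of the original toolkit API) reports the fraction of zeros in every 2-D weight tensor of the wrapped model; the kernels and their pruning masks should sit close to the scheduled sparsity:
import numpy as np

# Report the fraction of exactly-zero entries in each 2-D weight tensor
# (kernels and their pruning masks); biases and scalar counters are skipped.
for weight in model_for_pruning.weights:
    values = weight.numpy()
    if values.ndim >= 2:
        print(f"{weight.name}: {np.mean(values == 0):.1%} zeros")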
After training, the pruning wrappers need to be removed to streamline the model architecture for deployment. This step involves stripping the model of its pruning layers, which leaves it with the same architecture as the baseline but with many weights set to zero.
# Strip the pruning layers so the exported model has the baseline architecture
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_export.summary()
Pruning offers several advantages for deploying models to mobile and edge devices:
- Reduced Model Size: Pruning significantly cuts down the number of parameters, making the model lighter and faster to load and execute.
- Enhanced Compression: Sparse models resulting from pruning can be compressed more effectively, further reducing the storage space required (see the helper sketched just after this list).
- Maintained Performance: Properly implemented pruning retains the model’s accuracy, ensuring it remains effective in its application.
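The size comparisons at the end of this article measure exactly that: each saved model file is zip-compressed and its size on disk is reported. They rely on two small helpers from the quantization article, get_gzipped_model_size and print_metric, plus a MODEL_SIZE dictionary holding the baseline entries. If you don't have those on hand, a minimal sketch might look like this:
import os
import tempfile
import zipfile

MODEL_SIZE = {}  # assumed to already hold the baseline entries from the previous article

def get_gzipped_model_size(file_path):
    # Zip-compress the saved model file and return the compressed size in bytes.
    # Sparse (pruned) models compress far better than dense ones.
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file_path)
    return os.path.getsize(zipped_file)

def print_metric(metric_dict, metric_name):
    # Simple reporting helper: print one line per entry in the metric dictionary
    for key, value in metric_dict.items():
        print(f"{metric_name} - {key}: {value}")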
For further optimization, the pruned model can be quantized. Combining pruning and quantization can lead to substantial reductions in model size, often compressing the model to around one-tenth of its original size, making it ideal for deployment in resource-constrained environments.
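The convert_tflite helper used below also comes from the quantization article. If you don't have it available, a minimal sketch using TensorFlow Lite's post-training dynamic-range quantization could look like the following; the FILE_PRUNED_QUANTIZED_TFLITE path is simply an assumed output location:
FILE_PRUNED_QUANTIZED_TFLITE = 'pruned_quantized_model.tflite'  # assumed output path

def convert_tflite(model, file_path, quantize=False):
    # Convert a Keras model to TFLite, optionally applying post-training quantization
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    if quantize:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    with open(file_path, 'wb') as f:
        f.write(tflite_model)
    return tflite_model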
# Convert and quantize the pruned model.
pruned_quantized_tflite = convert_tflite(model_for_export, FILE_PRUNED_QUANTIZED_TFLITE, quantize=True)

# Compress and get the model size
MODEL_SIZE['pruned quantized tflite'] = get_gzipped_model_size(FILE_PRUNED_QUANTIZED_TFLITE)
print_metric(MODEL_SIZE, "gzipped model size in bytes")
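Finally, accuracy is compared between the pruned Keras model and the pruned-and-quantized TFLite file. The evaluate_tflite_model helper used below is again defined in the earlier article; a minimal sketch, assuming a single-input classifier evaluated one image at a time, could look like this:
import numpy as np

def evaluate_tflite_model(file_path, test_images, test_labels):
    # Run the TFLite interpreter over the test set and return top-1 accuracy
    interpreter = tf.lite.Interpreter(model_path=file_path)
    interpreter.allocate_tensors()
    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']

    correct = 0
    for image, label in zip(test_images, test_labels):
        # Add a batch dimension and match the float32 input expected by the model
        interpreter.set_tensor(input_index, np.expand_dims(image, axis=0).astype(np.float32))
        interpreter.invoke()
        prediction = np.argmax(interpreter.get_tensor(output_index)[0])
        correct += int(prediction == label)
    return correct / len(test_labels)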
# Get accuracy of pruned Keras and TF Lite models
ACCURACY = {}

_, ACCURACY['pruned model h5'] = model_for_pruning.evaluate(test_images, test_labels)
ACCURACY['pruned and quantized tflite'] = evaluate_tflite_model(FILE_PRUNED_QUANTIZED_TFLITE, test_images, test_labels)
print_metric(ACCURACY, 'accuracy')
As expected, the accuracy of the two models is very close.
Pruning is a powerful technique to optimize neural networks for mobile deployment, effectively reducing model size while preserving or even enhancing performance. When combined with quantization, pruning can make models exceptionally lightweight and fast, perfectly suited for mobile and other edge devices. This approach not only saves on computational resources but also ensures models are quicker to update and deploy, providing a seamless user experience on mobile platforms.