Triplet Loss in Tensorflow Similarity | by Ridadogrul | Sep, 2023

TensorFlow Similarity

TensorFlow Similarity is a TensorFlow library for similarity learning, which covers techniques such as self-supervised learning, metric learning, and contrastive learning.

With TensorFlow Similarity, you can train two main types of models:

Self-supervised models: Used to learn general data representations on unlabeled data to boost the accuracy of downstream tasks where you have few labels.
Similarity models: Output embeddings that allow you to find and cluster similar examples such as images representing the same object within a large corpus of examples.

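Install the library with pip if it is not already available:

```python
import tensorflow as tf

try:
    import tensorflow_similarity as tfsim  # main package
except ModuleNotFoundError:
    # Notebook (IPython/Colab) shell syntax; in a terminal run `pip install tensorflow_similarity`.
    !pip install tensorflow_similarity
    import tensorflow_similarity as tfsim
```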

Triplet Loss

Triplet Loss helps a network learn a distributed embedding from the notions of similarity and dissimilarity. Conceptually, the architecture consists of several parallel networks that share their weights, so every element of a triplet is embedded by the same function. At prediction time, an input is passed through a single copy of the network to compute its embedding. Mathematically, the loss is

$$L = \max\big(d(a, p) - d(a, n) + m,\ 0\big)$$

where:

a, the anchor, is a reference sample,
p, the positive, is a sample with the same label as a,
n, the negative, is a sample with a label different from a,
d is a function that measures the distance between two samples (for example, Euclidean or cosine distance),
and m is a margin: the loss is zero only once the negative is at least m farther from the anchor than the positive.
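To make the formula concrete, here is a small pure-Python sketch that evaluates it for a single triplet. It assumes Euclidean distance for d and a margin m = 1.0; both are illustrative choices, not values taken from TF Similarity.

```python
import math


def euclidean(u, v):
    """Euclidean distance d(u, v) between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def single_triplet_loss(anchor, positive, negative, margin=1.0):
    """L = max(d(a, p) - d(a, n) + m, 0) for one (anchor, positive, negative) triplet."""
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)


a, p = [0.0, 0.0], [3.0, 4.0]  # d(a, p) = 5
n_far = [6.0, 8.0]             # d(a, n) = 10: already beyond the margin
n_close = [0.0, 5.5]           # d(a, n) = 5.5: still inside the margin

print(single_triplet_loss(a, p, n_far))    # max(5 - 10 + 1, 0) -> 0.0
print(single_triplet_loss(a, p, n_close))  # max(5 - 5.5 + 1, 0) -> 0.5
```

Only the second triplet produces a gradient: the negative is farther than the positive, but not by the full margin.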

Triplet Loss and Contrastive Loss are two popular loss functions used in deep learning for tasks such as face recognition and image retrieval.

Although both loss functions aim to place similar samples closer to each other in the vector space than dissimilar samples, they differ in their approach and behavior.

Triplet Loss is less greedy than Contrastive Loss. It only tries to maintain a margin between the distances of negative pairs and the distances of positive pairs, so it is already satisfied once negatives are sufficiently farther from the anchor than positives.

Contrastive Loss, on the other hand, takes the margin into account only when comparing dissimilar pairs, and it does not consider where the similar pairs are at that moment. As a result, Contrastive Loss may reach a local minimum earlier, while Triplet Loss may continue to organize the vector space into a better state.

Triplet Loss also tolerates some intraclass variance, unlike Contrastive Loss, which essentially forces the distance between an anchor and any positive to 0. This lets Triplet Loss expand clusters to include outliers while still keeping a margin between samples from different clusters.

In short, the choice between Triplet Loss and Contrastive Loss depends on the requirements of the task and the behavior you want from the loss function.
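The difference in greediness shows up on a single anchor with one positive and one negative. The sketch below uses the classic contrastive formulation L = y·d² + (1 − y)·max(m − d, 0)² for a pair at distance d; that formula and the distances are illustrative assumptions, not code from TF Similarity.

```python
def contrastive_pair_loss(d, similar, margin=1.0):
    """Classic contrastive loss for one pair at distance d."""
    if similar:
        return d ** 2                   # pulls similar pairs all the way to distance 0
    return max(margin - d, 0.0) ** 2    # pushes dissimilar pairs only until they pass the margin


def triplet_loss_from_distances(d_ap, d_an, margin=1.0):
    """Triplet loss max(d(a, p) - d(a, n) + m, 0) from precomputed distances."""
    return max(d_ap - d_an + margin, 0.0)


d_ap, d_an = 0.5, 3.0  # positive fairly close, negative far away

print(triplet_loss_from_distances(d_ap, d_an))     # 0.0: triplet constraint already satisfied
print(contrastive_pair_loss(d_ap, similar=True))   # 0.25: still pulling the positive toward the anchor
print(contrastive_pair_loss(d_an, similar=False))  # 0.0
```

The triplet loss leaves this configuration alone, whereas the contrastive term keeps shrinking the anchor-positive distance, which is the intraclass-variance difference described above.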

```python
import tensorflow as tf
import tensorflow_similarity as tfsim


def triplet_loss(embeddings, labels):
    """Calculates the triplet loss with TF Similarity's built-in TripletLoss.

    Args:
        embeddings: A tensor of shape (batch_size, embedding_size).
        labels: A tensor of shape (batch_size,).

    Returns:
        A scalar tensor representing the triplet loss.
    """
    # TripletLoss builds the pairwise distance matrix internally, so it is
    # called like any Keras loss: labels as y_true, embeddings as y_pred.
    loss_fn = tfsim.losses.TripletLoss(distance="cosine")
    return loss_fn(labels, embeddings)
```

Invalid triplet masking

We can use broadcasting to enumerate the distance differences for all possible triplets in a batch and store them in a tensor of shape (batch_size, batch_size, batch_size). However, only a subset of these batch_size³ triplets is actually valid, so we need a corresponding mask to compute the loss correctly. We implement this helper function in three steps:

Compute a mask for distinct indices, i.e., i != j, i != k, and j != k.

Compute a mask for valid anchor-positive-negative triplets, i.e., labels[i] == labels[j] and labels[i] != labels[k].

Combine the two masks.

```python
import tensorflow as tf


def get_triplet_mask(labels):
    """Computes a mask for valid triplets.

    A triplet (i, j, k) is valid if
    `labels[i] == labels[j] and labels[i] != labels[k]`
    and `i`, `j`, `k` are distinct.

    Args:
        labels: Batch of integer labels. Shape: (batch_size,)

    Returns:
        Boolean mask indicating which triplets are valid.
        Shape: (batch_size, batch_size, batch_size)
    """
    # Step 1 - mask for distinct indices
    # shape: (batch_size, batch_size)
    indices_equal = tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool)
    indices_not_equal = tf.logical_not(indices_equal)
    i_not_equal_j = tf.expand_dims(indices_not_equal, 2)  # (batch_size, batch_size, 1)
    i_not_equal_k = tf.expand_dims(indices_not_equal, 1)  # (batch_size, 1, batch_size)
    j_not_equal_k = tf.expand_dims(indices_not_equal, 0)  # (1, batch_size, batch_size)
    # shape: (batch_size, batch_size, batch_size)
    distinct_indices = tf.logical_and(
        tf.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k
    )

    # Step 2 - mask for valid anchor-positive-negative triplets
    # labels_equal[i, j] is True when labels[i] == labels[j]; shape: (batch_size, batch_size)
    labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
    i_equal_j = tf.expand_dims(labels_equal, 2)  # (batch_size, batch_size, 1): anchor vs positive
    i_equal_k = tf.expand_dims(labels_equal, 1)  # (batch_size, 1, batch_size): anchor vs negative
    valid_indices = tf.logical_and(i_equal_j, tf.logical_not(i_equal_k))

    # Step 3 - combine the two masks
    return tf.logical_and(distinct_indices, valid_indices)
```
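A quick way to check the broadcasting, including which axis plays the positive and which the negative, is to compare a NumPy port of the mask against a brute-force triple loop. This is a verification sketch of our own, written in NumPy rather than TensorFlow; the function names are hypothetical.

```python
import itertools

import numpy as np


def triplet_mask_numpy(labels):
    """Vectorized mask: mask[i, j, k] is True when (i, j, k) is a valid triplet."""
    labels = np.asarray(labels)
    n = labels.shape[0]
    not_eye = ~np.eye(n, dtype=bool)
    # Distinct indices: broadcast the (n, n, 1), (n, 1, n) and (1, n, n) views.
    distinct = not_eye[:, :, None] & not_eye[:, None, :] & not_eye[None, :, :]
    same = labels[None, :] == labels[:, None]  # same[i, j] = labels[i] == labels[j]
    # Anchor i and positive j share a label; anchor i and negative k do not.
    valid = same[:, :, None] & ~same[:, None, :]
    return distinct & valid


def triplet_mask_bruteforce(labels):
    """Reference version that checks every (i, j, k) with explicit loops."""
    n = len(labels)
    mask = np.zeros((n, n, n), dtype=bool)
    for i, j, k in itertools.product(range(n), repeat=3):
        mask[i, j, k] = (
            len({i, j, k}) == 3
            and labels[i] == labels[j]
            and labels[i] != labels[k]
        )
    return mask


labels = [0, 0, 1, 1]
vectorized = triplet_mask_numpy(labels)
assert np.array_equal(vectorized, triplet_mask_bruteforce(labels))
print(int(vectorized.sum()))  # 4 anchors x 1 positive x 2 negatives = 8
```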

Batch-all strategy for online triplet mining

There are several strategies for forming, or mining, the triplets used by Triplet Loss, and the simplest is to use all valid triplets that can be formed from the samples in a batch. Thanks to the helper we have already implemented, this takes four steps:

Get a distance matrix of all possible pairs that can be formed from embeddings in a batch.

Apply broadcasting to this matrix to compute loss values for all possible triplets.

Set loss values of invalid or easy triplets to 0.

Average the remaining positive values to return a scalar loss.

```python
import tensorflow as tf


class BatchAllTripletLoss(tf.keras.losses.Loss):
    """Batch-all triplet loss on squared Euclidean distances.

    Relies on `get_triplet_mask` defined above.
    """

    def __init__(self, margin=1.0, name="batch_all_triplet_loss"):
        super().__init__(name=name)
        self.margin = margin

    def call(self, y_true, y_pred):
        """Computes the triplet loss.

        Args:
            y_true: Integer labels of shape (batch_size,) or (batch_size, 1).
            y_pred: Embeddings of shape (batch_size, embedding_dim).

        Returns:
            A scalar tensor representing the triplet loss.
        """
        labels = tf.reshape(y_true, [-1])
        # Step 1: squared Euclidean distance matrix, shape (batch_size, batch_size).
        distance_matrix = tf.reduce_sum(
            tf.square(tf.expand_dims(y_pred, 1) - tf.expand_dims(y_pred, 0)), axis=2)
        # Step 2: broadcast to (batch_size, batch_size, batch_size);
        # entry [i, j, k] = d(i, j) - d(i, k) + margin.
        anchor_positive_dists = tf.expand_dims(distance_matrix, 2)
        anchor_negative_dists = tf.expand_dims(distance_matrix, 1)
        triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin
        # Step 3: zero out invalid triplets with the mask, then easy ones (loss <= 0) with ReLU.
        mask = tf.cast(get_triplet_mask(labels), triplet_loss.dtype)
        triplet_loss = tf.nn.relu(triplet_loss * mask)
        # Step 4: average over the triplets that still have a positive loss.
        num_positive_losses = tf.reduce_sum(tf.cast(triplet_loss > 0, triplet_loss.dtype))
        return tf.reduce_sum(triplet_loss) / (num_positive_losses + 1e-8)
```
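To see what the four steps produce on actual numbers, here is a loop-based NumPy reference for batch-all mining on a toy batch of four 1-D embeddings. The toy data, the margin of 1.0, and the function name are illustrative assumptions; the sketch mirrors the class above step by step and could serve as a test oracle for it.

```python
import itertools

import numpy as np


def batch_all_triplet_loss_reference(embeddings, labels, margin=1.0):
    """Loop-based batch-all triplet loss with squared Euclidean distances.

    Returns (loss, number of triplets with a positive loss).
    """
    emb = np.asarray(embeddings, dtype=np.float64)
    n = len(labels)
    # Step 1: pairwise squared Euclidean distances.
    dist = ((emb[:, None, :] - emb[None, :, :]) ** 2).sum(axis=-1)
    positive_losses = []
    for i, j, k in itertools.product(range(n), repeat=3):
        valid = len({i, j, k}) == 3 and labels[i] == labels[j] and labels[i] != labels[k]
        if not valid:
            continue
        # Step 2: loss for this triplet.
        loss = dist[i, j] - dist[i, k] + margin
        # Step 3: easy triplets (loss <= 0) are ignored.
        if loss > 0:
            positive_losses.append(loss)
    # Step 4: average over the triplets that still violate the margin.
    return float(sum(positive_losses)) / (len(positive_losses) + 1e-8), len(positive_losses)


embeddings = [[0.0], [2.0], [1.0], [3.0]]
labels = [0, 0, 1, 1]
loss, num_hard = batch_all_triplet_loss_reference(embeddings, labels)
print(num_hard, round(loss, 6))  # 6 4.0
```

Six of the eight valid triplets still violate the margin, each by 4.0, so the averaged loss is 4.0; the other two are easy and contribute nothing.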

Let’s look at how to use TensorFlow Similarity’s triplet loss in a project.

Caltech-UCSD Birds 200 (CUB-200) is an image dataset of 200 bird species, mostly North American. The 2010 version contains 6,033 images and the 2011 version contains 11,788 images. Annotations include bounding boxes and segmentation labels.


```python
import tensorflow as tf
import tensorflow_datasets as tfds

BATCH_SIZE = 32
IMAGE_RES = 32

(training_set, validation_set), dataset_info = tfds.load(
    "caltech_birds2011",  # assumed: the dataset name was lost from the original snippet
    split=["train[:70%]", "train[70%:]"],
    with_info=True,
    as_supervised=True,
)
num_classes = dataset_info.features["label"].num_classes

num_training_examples = sum(1 for _ in training_set)
num_validation_examples = sum(1 for _ in validation_set)

print("Total Number of Classes: {}".format(num_classes))
print("Total Number of Training Images: {}".format(num_training_examples))
print("Total Number of Validation Images: {} \n".format(num_validation_examples))

def format_image(image, label):
    # Resize to IMAGE_RES x IMAGE_RES and scale pixel values to [0, 1].
    image = tf.image.resize(image, (IMAGE_RES, IMAGE_RES)) / 255.0
    return image, label

train_batches = training_set.shuffle(num_training_examples // 4).map(format_image).batch(BATCH_SIZE).prefetch(1)
validation_batches = validation_set.map(format_image).batch(BATCH_SIZE).prefetch(1)
```

This is the data preprocessing step for training the model.
The `format_image` function takes an image and its label, resizes the image to 32×32 pixels, and normalizes the pixel values to [0, 1] by dividing them by 255.0.
`training_set` and `validation_set` are then turned into batched datasets with a batch size of 32. The `shuffle` method randomly shuffles the training data before batching, and the `prefetch` method prepares the next batch while the current one is being processed.
This step ensures the model is trained on images of consistent size and scale, which helps training.


Model definition

`SimilarityModel()` extends `tf.keras.Model` with additional functionality that lets you index and search for similar-looking examples.

```python
import tensorflow as tf
import tensorflow_similarity as tfsim


def create_model():
    inputs = tf.keras.layers.Input(shape=(32, 32, 3))
    # Scales raw [0, 255] pixels to [0, 1]; drop this layer if the inputs
    # were already divided by 255 (as format_image above does).
    x = tf.keras.layers.Rescaling(1 / 255)(inputs)
    x = tf.keras.layers.Conv2D(64, 3, activation="relu")(x)
    x = tf.keras.layers.Conv2D(64, 3, activation="relu")(x)
    x = tf.keras.layers.MaxPool2D()(x)
    x = tf.keras.layers.Conv2D(128, 3, activation="relu")(x)
    x = tf.keras.layers.Conv2D(128, 3, activation="relu")(x)
    x = tf.keras.layers.Flatten()(x)
    # Smaller embeddings give faster lookups; larger ones improve accuracy up to a point.
    outputs = tfsim.layers.MetricEmbedding(128)(x)
    return tfsim.models.SimilarityModel(inputs, outputs)
```


TensorFlow Similarity uses an extended `compile()` method that lets you optionally specify `distance_metrics` (metrics computed over the distances between embeddings) and the distance used by the indexer.

By default, `compile()` tries to infer the distance you are using from the first loss specified. If you use multiple losses and the distance-based loss is not the first one, you need to pass the distance function explicitly through the `distance=` parameter.

```python
model, loss = create_model(), tfsim.losses.TripletLoss(distance="cosine")  # cosine: assumed choice
model.compile(optimizer=tf.keras.optimizers.Adam(LR), loss=loss)
```


Similarity models are trained like normal models.

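```python
import matplotlib.pyplot as plt

EPOCHS = 50  # @param {type:"integer"}
# Assumed reconstruction: the original fit() arguments were lost in extraction.
history = model.fit(train_batches, epochs=EPOCHS, validation_data=validation_batches)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.legend(["loss", "val_loss"])
plt.title(f"Loss: {loss.name} - LR: {LR}")
```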


Indexing is where things differ from traditional classification models. Because the model has learned to output an embedding that represents each example’s position in the learned metric space, we need a way to find which known examples are closest to a query in order to determine its class (also known as nearest-neighbor classification).

To do so, we create an index of known examples. We take 20 examples per class from the training set for each class in `class_list` and build the index with the model’s `index()` method.

```python
class_list = list(range(20))  # List of class labels from 0 to 19
x_index, y_index = tfsim.samplers.select_examples(x_train, y_train, class_list, num_examples_per_class=20)
model.index(x_index, y_index, data=x_index)
```
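Conceptually, looking up an example in the index is a nearest-neighbor search in embedding space. The NumPy sketch below shows the idea with brute-force cosine distance over a toy index of 2-D embeddings; the helper names and data are hypothetical, and this illustrates the concept rather than how TF Similarity’s indexer is implemented.

```python
import numpy as np


def cosine_distances(queries, index):
    """1 - cosine similarity between every query row and every indexed row."""
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    x = index / np.linalg.norm(index, axis=1, keepdims=True)
    return 1.0 - q @ x.T


def knn_lookup(queries, index_embeddings, index_labels, k=3):
    """Return (labels, distances) of the k nearest indexed examples per query."""
    dist = cosine_distances(np.asarray(queries, float), np.asarray(index_embeddings, float))
    order = np.argsort(dist, axis=1)[:, :k]  # indices of the closest examples first
    return np.asarray(index_labels)[order], np.take_along_axis(dist, order, axis=1)


index_embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
index_labels = [0, 0, 1, 1]
neighbor_labels, neighbor_dists = knn_lookup([[1.0, 0.05]], index_embeddings, index_labels, k=2)
print(neighbor_labels)  # [[0 0]]
```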


To “classify” examples, we need to look up their *k* nearest neighbors in the index.

Here we query a single random example for each class from the test dataset using `select_examples()` and then find its nearest neighbors using the `lookup()` method.

```python
import numpy as np
import tensorflow_similarity.visualization as tfsim_visualization

# re-run to test on other examples
num_neighbors = 5

# select one random test example per class
x_display, y_display = tfsim.samplers.select_examples(x_test, y_test, class_list, 1)

# look up the nearest neighbors in the index
nns = model.lookup(x_display, k=num_neighbors)

# display each query next to its neighbors, ordered by class
for idx in np.argsort(y_display):
    tfsim_visualization.viz_neigbors_imgs(x_display[idx], y_display[idx], nns[idx], fig_size=(16, 2), cmap="Greys")
```


To tell whether an example matches a given class, we first need to `calibrate()` the model to find the optimal cutpoint. The cutpoint is the maximum distance below which returned neighbors are considered to be of the same class. Increasing the threshold improves the recall at the expense of the precision.

By default, calibration uses the F1 score to optimally balance precision and recall; however, you can specify your own targets and change the calibration metric to better suit your use case.

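```python
num_calibration_samples = 1000
# Assumed arguments: the original call was truncated after extra_metrics.
calibration = model.calibrate(
    x_train[:num_calibration_samples],
    y_train[:num_calibration_samples],
    extra_metrics=["precision", "recall", "binary_accuracy"],
    verbose=1,
)
```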
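The idea behind calibration can be sketched as a threshold sweep: for each candidate distance, accept every query whose nearest neighbor is within it, compute precision, recall, and F1, and keep the threshold with the best F1. This simplified sketch assumes each calibration query is summarized by its nearest-neighbor distance and whether that neighbor had the correct label; the function name and data are hypothetical, and TF Similarity’s `calibrate()` computes its metrics in its own way.

```python
def calibrate_cutpoint(distances, correct):
    """Return (threshold, f1, precision, recall) for the threshold with the best F1."""
    best = None
    for t in sorted(set(distances)):
        tp = sum(1 for d, c in zip(distances, correct) if d <= t and c)
        fp = sum(1 for d, c in zip(distances, correct) if d <= t and not c)
        fn = sum(1 for d, c in zip(distances, correct) if d > t and c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        if best is None or f1 > best[1]:
            best = (t, f1, precision, recall)
    return best


# Nearest-neighbor distance per calibration query, and whether that neighbor had the right label.
distances = [0.05, 0.10, 0.20, 0.30, 0.45, 0.60]
correct = [True, True, True, False, True, False]
cutpoint, f1, precision, recall = calibrate_cutpoint(distances, correct)
print(cutpoint, round(f1, 3))  # 0.45 0.889
```

In this toy data, precision first drops at 0.30, yet the best F1 is reached at 0.45 because recall keeps rising faster than precision falls.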

Metrics plotting

Let’s plot the performance metrics to see how they evolve as the distance threshold increases.

We clearly see an inflection point where the precision and recall intersect; however, this is not the `optimal_cutpoint`, because the recall continues to increase faster than the precision decreases. Different use cases have different performance profiles, which is why each model needs to be calibrated.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
x = calibration.thresholds["distance"]
ax.plot(x, calibration.thresholds["precision"], label="precision")
ax.plot(x, calibration.thresholds["recall"], label="recall")
ax.plot(x, calibration.thresholds["f1"], label="f1 score")
ax.set_xlabel("distance threshold")
ax.legend()
ax.set_title("Metric evolution as distance increases")
plt.show()
```

Precision/Recall curve

As the precision/recall curve below shows, the curve is not smooth. This is because the recall can improve independently of the precision, causing a seesaw pattern.

Additionally, the model does extremely well on known classes and less well on unseen ones, which contributes to the flat curve at the beginning, followed by a sharp decline as the distance threshold increases and examples are farther away from the indexed examples.

```python
fig, ax = plt.subplots()
ax.plot(calibration.thresholds["recall"], calibration.thresholds["precision"])
ax.set_title("Precision Recall Curve")
```


The purpose of `match()` is to let you use your similarity model to make classification predictions. It does this by finding the nearest neighbors of a set of query examples and returning an inferred label based on the neighbors’ labels and the matching strategy used (MatchNearest by default).

Note: unlike traditional models, the `match()` method potentially returns -1 when there are no indexed examples below the cutpoint threshold. The -1 class should be treated as “unknown”.

```python
from tabulate import tabulate

num_matches = 10  # @param {type:"integer"}

# match() looks up each query's nearest neighbors in the index and returns the
# inferred label, or -1 when no neighbor is closer than the calibrated cutpoint.
matches = model.match(x_test[:num_matches], cutpoint="optimal")

# Compare the predicted labels to the expected labels
rows = []
for idx in range(num_matches):
    predicted_label = int(matches[idx])
    expected_label = int(y_test[idx])
    correct = predicted_label == expected_label
    rows.append([predicted_label, expected_label, correct])

# Print the results using tabulate
print(tabulate(rows, headers=["Predicted", "Expected", "Correct"]))
```
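The nearest-neighbor matching rule described above can be sketched in plain Python: take the closest indexed neighbor’s label if it lies within the cutpoint, otherwise return the no-match label (-1 by default). This is a simplified model with made-up distances and a hypothetical function name, not TF Similarity’s code; treating a distance exactly equal to the cutpoint as a match is also an assumption.

```python
def match_nearest(neighbors, cutpoint, no_match_label=-1):
    """Label each query with its nearest neighbor's label, or no_match_label if too far.

    neighbors: one list of (distance, label) pairs per query, in any order.
    """
    predictions = []
    for candidates in neighbors:
        distance, label = min(candidates)  # smallest distance = nearest neighbor
        predictions.append(label if distance <= cutpoint else no_match_label)
    return predictions


neighbors = [
    [(0.12, 3), (0.40, 5)],  # close to a class-3 example
    [(0.80, 7), (0.95, 7)],  # nothing within the cutpoint -> unknown
]
print(match_nearest(neighbors, cutpoint=0.45))  # [3, -1]
```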

Confusion Matrix

Now that we have a better sense of what the `match()` method does, let’s scale up to 1,000 samples per class and evaluate how well our model predicts the correct classes.

As expected, while the model’s prediction performance is good, it is not competitive with a dedicated classification model. However, this lower accuracy comes with a unique advantage: the model can classify classes that were not seen during training.

**NOTE** `tf.math.confusion_matrix` doesn’t support negative class labels, so instead of -1 we use an extra class index, one past the last real class, as our **unknown class**. As mentioned earlier, unknown examples are test examples whose closest neighbor is farther away than the cutpoint threshold.

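```python
import tensorflow_similarity.visualization as tfsim_visualization

# Axis labels for the confusion matrix: one per class in class_list, plus an
# extra "Unknown" class for queries with no neighbor within the cutpoint.
# Assumes class_list holds contiguous integer labels starting at 0.
labels = [str(c) for c in class_list] + ["Unknown"]
unknown_label = len(class_list)
num_examples_per_class = 1000
cutpoint = "optimal"

x_confusion, y_confusion = tfsim.samplers.select_examples(x_test, y_test, class_list, num_examples_per_class)

matches = model.match(x_confusion, cutpoint=cutpoint, no_match_label=unknown_label)
cm = tfsim_visualization.confusion_matrix(
    matches,
    y_confusion,
    labels=labels,
    title="Confusion Matrix for cutpoint:%s" % cutpoint,
)
```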
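The remapping trick from the note can be shown on its own: replace the -1 "no match" label with an extra index one past the real classes, then count (true, predicted) pairs. This NumPy sketch uses a hypothetical helper and toy labels to illustrate the idea; it is not how `tfsim_visualization.confusion_matrix` is implemented.

```python
import numpy as np


def confusion_with_unknown(y_true, y_pred, num_classes, no_match_label=-1):
    """Confusion matrix whose last column counts 'unknown' (unmatched) predictions."""
    unknown = num_classes  # extra index one past the real classes
    y_pred = np.where(np.asarray(y_pred) == no_match_label, unknown, y_pred)
    cm = np.zeros((num_classes + 1, num_classes + 1), dtype=int)
    np.add.at(cm, (np.asarray(y_true), y_pred), 1)  # rows: true class, columns: predicted
    return cm


cm = confusion_with_unknown(y_true=[0, 0, 1, 1], y_pred=[0, -1, 1, 0], num_classes=2)
print(cm)
# [[1 0 1]
#  [1 1 0]
#  [0 0 0]]
```

The last row stays empty because no test example truly belongs to the "unknown" class; only the last column collects unmatched predictions.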

Index information

As with `model.summary()`, you can get information about the index configuration and its performance using `model.index_summary()`.


