PyTorch conversion of the excellent post on the same topic in TensorFlow. Simply an implementation of a triplet loss with online mining of candidate triplets, as used in semi-supervised learning.
Install with:
pip install online_triplet_loss
Then import with:
from online_triplet_loss.losses import *
PS: Requires PyTorch version 1.1.0 or above.
from torch import nn
import torch
from online_triplet_loss.losses import *

# Toy model: one 10-dimensional embedding per class (10 classes)
model = nn.Embedding(10, 10)
labels = torch.randint(high=10, size=(5,)) # our five labels
embeddings = model(labels)
print('Labels:', labels)
print('Embeddings:', embeddings)
# Batch-hard strategy: for each anchor, mine the hardest positive and hardest negative in the batch
loss = batch_hard_triplet_loss(labels, embeddings, margin=100)
print('Loss:', loss)
loss.backward()
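For intuition about what a batch-hard strategy computes, here is a minimal pure-PyTorch sketch. It is not the library's code and may differ from it in details; the names `pairwise_distances` and `batch_hard_sketch` are hypothetical and exist only for this illustration. For each anchor `a`, it takes the farthest sample with the same label (hardest positive `p`) and the closest sample with a different label (hardest negative `n`), then averages `max(d(a, p) - d(a, n) + margin, 0)` over the batch.

```python
import torch


def pairwise_distances(x):
    """Euclidean distance between every pair of rows of x, shape (B, B)."""
    dot = x @ x.t()
    sq_norm = torch.diag(dot)
    d2 = (sq_norm.unsqueeze(0) - 2.0 * dot + sq_norm.unsqueeze(1)).clamp(min=0.0)
    # sqrt has an infinite gradient at 0, so nudge exact zeros before the
    # sqrt and reset them to 0 afterwards.
    zero = (d2 == 0).float()
    return torch.sqrt(d2 + zero * 1e-16) * (1.0 - zero)


def batch_hard_sketch(labels, embeddings, margin):
    """Illustrative batch-hard triplet loss (hypothetical helper)."""
    dist = pairwise_distances(embeddings)
    same = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    eye = torch.eye(labels.size(0), device=embeddings.device)
    pos_mask = same - eye   # same label, excluding the anchor itself
    neg_mask = 1.0 - same   # different label

    # Hardest positive: largest distance among same-label samples
    # (0 when an anchor has no positive in the batch).
    hardest_pos = (dist * pos_mask).max(dim=1)[0]

    # Hardest negative: smallest distance among different-label samples.
    # Adding the row maximum to non-negatives keeps them from being chosen
    # whenever the anchor has at least one negative.
    row_max = dist.max(dim=1, keepdim=True)[0]
    hardest_neg = (dist + row_max * (1.0 - neg_mask)).min(dim=1)[0]

    return torch.relu(hardest_pos - hardest_neg + margin).mean()


if __name__ == "__main__":
    labels = torch.tensor([0, 0, 1])
    embeddings = torch.tensor([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]],
                              requires_grad=True)
    loss = batch_hard_sketch(labels, embeddings, margin=1.0)
    print('Loss:', loss)
    loss.backward()
```

Because the mining happens inside the loss on whatever batch it receives, no triplets need to be prepared ahead of time; this is what "online" mining refers to.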
# Batch-all strategy: reuses `model` and `labels` from the previous example
embeddings = model(labels)
print('Labels:', labels)
print('Embeddings:', embeddings)
# Also returns the fraction of valid triplets with a positive loss
loss, fraction_pos = batch_all_triplet_loss(labels, embeddings, squared=False, margin=100)
print('Loss:', loss)
loss.backward()
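As a rough illustration of a batch-all strategy, the sketch below builds every valid (anchor, positive, negative) triplet in the batch and averages the loss over the triplets that still violate the margin. It is written from the general description of the technique, not from the library's source, so details may differ; `_pairwise_distances` and `batch_all_sketch` are hypothetical names used only here.

```python
import torch


def _pairwise_distances(x):
    """Euclidean distance between every pair of rows of x, shape (B, B)."""
    dot = x @ x.t()
    sq_norm = torch.diag(dot)
    d2 = (sq_norm.unsqueeze(0) - 2.0 * dot + sq_norm.unsqueeze(1)).clamp(min=0.0)
    zero = (d2 == 0).float()  # guard sqrt against an infinite gradient at 0
    return torch.sqrt(d2 + zero * 1e-16) * (1.0 - zero)


def batch_all_sketch(labels, embeddings, margin):
    """Illustrative batch-all triplet loss (hypothetical helper).

    Returns (loss, fraction of valid triplets with positive loss).
    """
    dist = _pairwise_distances(embeddings)
    same = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    not_self = 1.0 - torch.eye(labels.size(0), device=embeddings.device)

    # valid[a, p, n] = 1 when p != a, label[p] == label[a], label[n] != label[a]
    valid = (same * not_self).unsqueeze(2) * (1.0 - same).unsqueeze(1)

    # triplet[a, p, n] = d(a, p) - d(a, n) + margin, zeroed for invalid triplets
    triplet = dist.unsqueeze(2) - dist.unsqueeze(1) + margin
    triplet = torch.relu(triplet * valid)

    num_positive = (triplet > 1e-16).float().sum()
    fraction_pos = num_positive / (valid.sum() + 1e-16)
    loss = triplet.sum() / (num_positive + 1e-16)
    return loss, fraction_pos


if __name__ == "__main__":
    labels = torch.tensor([0, 0, 1])
    embeddings = torch.tensor([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
    loss, fraction_pos = batch_all_sketch(labels, embeddings, margin=1.0)
    print('Loss:', loss, 'Fraction of positive triplets:', fraction_pos)
```

Averaging only over the triplets with positive loss keeps the many already-satisfied triplets from diluting the gradient signal, and the returned fraction gives a simple indication of how many triplets are still being violated.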
References
- Triplet Loss and Online Triplet Mining in TensorFlow
- FaceNet paper
- adambielski's nice implementation (unfortunately, it switches context between CPU and GPU)