Whale identification – 5th place approach using siamese networks with adversarial training


  1. What’s it all about?
  2. Dataset overview
  3. Model in details
  4. Training design in details
  5. Ensembling and stacking
  6. Further discussion
  7. Credits

Code can be found on github here.

1. What’s it all about?

Building a global database of individual whales has been an ambitious mission for happywhale.com. For centuries, scientists have relied on the shape of the tail, as well as the unique markings found in each whale photo, to identify its species. So far, the dataset has been built only from manual labeling by individuals. The process is tedious, and the datasets are quite under-utilized.

This competition challenges us to design a robust network that is able to identify a whale given only a single image of its tail (fluke).

What motivated me to join this competition was the opportunity to try one-shot learning on a real-world use case. In the end, the experience turned out to be wonderful: not only did I learn a lot through the process, I also managed to build a powerful neural network model that achieved an amazing accuracy of 0.96! (It easily beats humans like me on this task :))

What’s more, it has been an incredible experience working together with my partner ZFTurbo on this task.

2. Dataset Overview

There are in total 25,361 images for training, and 7,960 images for test. Most of them are normal vertical images taken of a whale tail, such as below:

Sample images from competition dataset

Each image has a relatively high resolution in its original format, e.g. 1050 x 600.

There are, however, some outlier images, where either the photo dimensions are heavily distorted or the image is not displayed properly, for example:

Outlier image 1 – dimension distorted
Outlier image 2 – photo not taken properly
Outlier image 3 – image not shown clearly

Since there are only a few problematic images like these, we will simply ignore them here and assume that they do not have much impact on the models. But as a matter of good practice when building models, it is always useful to get to know your data first, including the outliers :)

The challenging part of this task is that the label distribution is extremely skewed.

The 25,361 training images fall into 5,005 classes in total. What’s more, 9,664 of the images belong to a class called new_whale, which you can think of as ‘unidentified whale’. The interesting thing about new_whale is that you cannot simply treat those 9,664 images as one class for training: some of them may belong to the same category, but we just don’t know! So the easiest and safest way is to set them aside from training, unless you have some smart way of using that data.

This leaves us with the remaining 15K or so useful training images (25,361 - 9,664 = 15,697), which fall into 5,004 classes. The label distribution of these remaining images is also very skewed, for example:

Class sizes distribution, source: https://www.kaggle.com/kretes/eda-distributions-images-and-no-duplicates

This shows that 2000 of the 5K classes have only one image. About 1250 classes have 2 images, and so on.

One might start with a typical multi-class classification ConvNet, or an ImageNet pretrained model, to approach this task, and we explored this at the beginning of the competition. However, the results were not good, and it is not easy to train such a network well given the nature of this data. The label distribution is long-tailed, and most classes have only one or a few images, which is too little to train a classifier well on those classes.

On the other hand, most ImageNet pretrained models were trained to learn features of things such as trees, cars, cats and dogs, whereas the features a model needs to learn for this task are the tiny differences between whale flukes, such as the ones below:

A good model needs to recognize those tiny features on a whale fluke in order to tell whether two whales are the same

3. Model in details

3.1 Siamese Neural Network

One of the amazing ideas for dealing with a one-shot problem in this context is the Siamese Neural Network (SNN).

Basically, an SNN trains a ConvNet to extract domain-specific features from two input images, and uses those extracted features to decide whether the two images belong to the same category (1) or to different categories (0).

During training, the input always consists of pairs of images. We feed both images of a pair through the same ConvNet, which outputs a fixed-length feature vector (e.g. a 512-dimensional vector) representing each image. Then we calculate the distance, or similarity, between the two feature vectors and output a probability (e.g. through a sigmoid layer) that indicates how likely the two images are to be in the same category.

Below is a simple illustration of an SNN architecture.

SNN Prototype
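To make the data flow concrete, here is a minimal NumPy sketch of the forward pass only. It is not our actual model: the branch here is a single random linear layer with ReLU and the head is a logistic layer over the absolute feature difference. The names `branch`, `head`, `siamese_forward` and all weights are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy weights; a real branch model is a deep ConvNet.
W_branch = rng.normal(size=(64 * 64, 512)) * 0.01  # flattened image -> 512-d features
w_head = rng.normal(size=512) * 0.1                # feature difference -> logit
b_head = 0.0

def branch(image):
    """Map one image to a fixed-length feature vector (weights shared by both inputs)."""
    return np.maximum(image.reshape(-1) @ W_branch, 0.0)

def head(feat_a, feat_b):
    """Turn two feature vectors into a probability that the images match."""
    logit = np.abs(feat_a - feat_b) @ w_head + b_head
    return 1.0 / (1.0 + np.exp(-logit))

def siamese_forward(image_a, image_b):
    # The same branch is applied to both images before the head compares them.
    return head(branch(image_a), branch(image_b))

img_a = rng.random((64, 64))
img_b = rng.random((64, 64))
p = siamese_forward(img_a, img_b)
print(p)  # a probability in (0, 1); meaningless here because the weights are untrained
```

The point of the sketch is the weight sharing: both images go through the same `branch`, so the head only ever compares features produced by one and the same extractor.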

3.2 Branch model and head model

3.21 Branch Model (ConvNets Layer)

As discussed above, the SNN architecture consists of a branch model (ConvNet layers) and a head model (distance layer). Martin’s great solution shows a detailed implementation of both models, which we used as part of our final solution, so I will not re-paste them here.

In addition to Martin’s model, we also tried:

  • Using 3-channel (i.e. RGB) input images
  • More input sizes (384, 512, 768) rather than 384 only
  • Different tweaks to the convolutional layers, such as adding Squeeze-and-Excitation layers (SE-Net). This gives more model diversity in our final ensemble.

In addition to self-designed ConvNets, we also tried ImageNet pretrained models, such as ResNet50, DenseNet and InceptionV3, as our branch models. This turned out to be very useful in our final ensembled submission, because ImageNet models tend to learn very differently from self-designed ConvNets trained from scratch, and ensembling ImageNet models with custom CNNs gave quite a boost to the LB score.

In fact, a pretrained DenseNet121 was the best performing model, scoring 0.959 MAP5 on both the public and private LB. Our best self-designed ConvNet trained from scratch scored 0.957 on the public LB and 0.954 on the private LB. Both reported scores are for single-model inference using the k-fold approach (details below).

To use an ImageNet model, simply replace the branch model code with the following:

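from keras.applications.inception_v3 import InceptionV3
from keras.models import Model

# img_shape is the input shape already defined for the branch model
ConvModel = InceptionV3(input_shape=img_shape, weights='imagenet')
x = ConvModel.layers[-2].output  # features just before the final classification layer
branch_model = Model(ConvModel.input, x)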

3.22 Head Model (Distance Layer)

The head model in Martin’s implementation is great, as it uses a more complicated neural network to predict the similarity between two feature vectors, rather than a simpler distance metric such as L1 or L2, which is the more popular choice in models such as triplet networks.

I suspect that the advantage of a simpler approach, such as L2 distance, is that it is much faster to compute than a full forward pass through a neural network. This would matter when scaling up to large data in production. But for this competition, we aim for accuracy.
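For reference, this is all that the simpler alternatives amount to. The snippet is only a small illustration with made-up feature vectors and helper names, not code from our solution.

```python
import numpy as np

def l1_distance(feat_a, feat_b):
    """Sum of absolute differences between two feature vectors."""
    return float(np.sum(np.abs(feat_a - feat_b)))

def l2_distance(feat_a, feat_b):
    """Euclidean distance between two feature vectors."""
    return float(np.sqrt(np.sum((feat_a - feat_b) ** 2)))

a = np.array([1.0, 2.0, 3.0])
b = np.array([1.0, 0.0, 7.0])
print(l1_distance(a, b))  # 6.0
print(l2_distance(a, b))  # sqrt(20), about 4.472
```

A fixed metric like this has no trainable parameters, which is why it is cheap, whereas a learned head can weight each feature dimension differently.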

3.3 Bounding box model

It has been discussed in Martin’s forum that using cropped images of the whale fluke could give better results than feeding in the original images during training. This is easy to understand: by feeding cropped images to the model, we remove the majority of the background noise, such as water or flying birds, and let the model concentrate on the more important details, which can improve performance given that the amount of training data is relatively small.

We cropped the tails out of the original image

The bounding box model is relatively easy to implement. We just build a simpler ConvNet (similar to the VGG16 architecture) and let it output four floating-point numbers, corresponding to the x and y coordinates of two corners of the rectangular box.

The dataset we used for training the bounding box model can be downloaded from here. Each row of this file contains the x and y coordinates of many feature points on the whale fluke, and we just need to find the max and min of both x and y to get the two corner points of the crop box, i.e. (max_x, max_y, min_x, min_y). However, be aware that this dataset is based on images from the Playground version of the Humpback Competition, so you will need to use the Playground data to train the model, and the image names might not be found in the latest Humpback Competition dataset.
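As a rough sketch of the min/max step, assuming the feature points of one image have already been parsed into (x, y) tuples (the file parsing and the helper names `bounding_box` and `crop` are my own placeholders, and the box is returned in min-first order rather than the order written above):

```python
import numpy as np

def bounding_box(points):
    """points: iterable of (x, y) feature points. Returns (min_x, min_y, max_x, max_y)."""
    pts = np.asarray(points, dtype=float)
    min_x, min_y = pts.min(axis=0)
    max_x, max_y = pts.max(axis=0)
    return float(min_x), float(min_y), float(max_x), float(max_y)

def crop(image, box):
    """Crop an H x W (x C) array to the box, clamped to the image borders."""
    h, w = image.shape[:2]
    x0, y0, x1, y1 = box
    x0, y0 = max(int(np.floor(x0)), 0), max(int(np.floor(y0)), 0)
    x1, y1 = min(int(np.ceil(x1)), w), min(int(np.ceil(y1)), h)
    return image[y0:y1, x0:x1]

points = [(120, 80), (400, 60), (760, 95), (430, 300)]  # made-up fluke points
box = bounding_box(points)
print(box)                      # (120.0, 60.0, 760.0, 300.0)
image = np.zeros((600, 1050, 3))
print(crop(image, box).shape)   # (240, 640, 3)
```

These four numbers are what the bounding box model is trained to regress, so that the same crop can then be predicted for images that have no annotated points.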

One thing to note here is that the raw predictions of the bounding box model may contain some inaccuracies across the dataset. This is due to the simplicity of the model, as well as to some of the outlier images in the training dataset. We tried building a more advanced model (e.g. using an ImageNet pretrained ResNet50) and removing those outlier images, and indeed noticed an improvement in the accuracy of the bounding box predictions.

4. Training design in details

4.1 Image augmentation

Almost all CV models rely on good augmentation. This is even more true when we only have around 15K training images in total.

We adopted on-the-fly image augmentation and used tricks like shear, rotation, flipping, contrast changes, Gaussian noise, blurring, color augmentations, greying, random crops, and so on. We noticed that data augmentation helps a lot during training, especially for ImageNet pretrained models, as those models can easily memorize all images within a few epochs (i.e. they learn too fast and overfit) if augmentation is not used.

A good set of augmentations helps boost the score by around 0.01 – 0.02. Shown below is a set of randomly augmented images during training.

original image (bottom right) as well as some of its augmented images during training.
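A stripped-down sketch of what on-the-fly augmentation looks like is given here, covering only four of the tricks listed above (flip, contrast, Gaussian noise, greying). The probabilities and ranges are placeholder values I picked for illustration, not the settings we trained with, and the function name `augment` is hypothetical.

```python
import numpy as np

def augment(image, rng):
    """Randomly augment an H x W x 3 float image with values in [0, 1]."""
    out = image.copy()
    if rng.random() < 0.5:   # horizontal flip
        out = out[:, ::-1, :]
    if rng.random() < 0.5:   # contrast change around the image mean
        factor = rng.uniform(0.7, 1.3)
        out = (out - out.mean()) * factor + out.mean()
    if rng.random() < 0.5:   # additive Gaussian noise
        out = out + rng.normal(0.0, 0.02, size=out.shape)
    if rng.random() < 0.2:   # greying: replace all channels by their mean
        out = np.repeat(out.mean(axis=2, keepdims=True), 3, axis=2)
    return np.clip(out, 0.0, 1.0)

rng = np.random.default_rng(0)
image = rng.random((96, 192, 3))
batch = [augment(image, rng) for _ in range(8)]  # eight different views of one image
```

Because the augmentation is applied each time an image is drawn, the model rarely sees exactly the same pixels twice, which is what slows down memorization.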

4.2 Adversarial training procedures

The training was the most fun part of this competition.

There is an enormous number of unmatched image pairs that we can create out of the 15K images. If we only randomly sample a subset of them as negative examples, the model will soon become good enough to reach close to 0 log-loss on most of them, and most of those negative pairs become redundant for training.

However, there still exist difficult unmatched pairs that are not easy for the model to tell apart, as shown in the figure below:

‘Easy’ v.s. ‘Hard’ cases of unmatched training pairs

A good training strategy is to constantly ‘mine’ those difficult cases in each epoch, based on the performance of the current model snapshot, so that the training loop always contains a certain percentage of hard cases that the model still struggles to tell apart. This is similar to triplet mining in FaceNet.

At the beginning of training, we need to feed the model easier cases by randomly sampling matching and unmatching pairs. If we start by feeding the model hard samples (unmatched pairs that appear ‘more similar’ than matched pairs), the model may learn to predict similar images as different whales (0) and dissimilar images as the same whale (1). As a result, the training loss will diverge rather than converge, and training eventually fails.

As training progresses and the model becomes stronger, we start to feed in more difficult negative pairs.

Since any two whales from different classes can form an unmatched (negative) pair, while only whales from the same class can form a matched (positive) pair, there are far more unmatched pairs than matched pairs available during training. Because of this, we only form adversarial pairs for unmatched whales, not for matched whales. For matched whales (positive pairs) we use random sampling instead.

Just to add: I also tried adversarial training for matching whales, but because of the limited number of positive cases, there was little to no performance gain compared to random sampling.

The algorithm for adversarial mining of negative pairs is as follows:

  • Use the trained model to calculate a train-vs-train score matrix T, where each element Tij is the model’s predicted probability that image i and image j match.
  • Negate all values of T (i.e. they now lie between -1 and 0).
  • Find all entries of T that correspond to matching pairs and set them to a large positive number (e.g. 10000), so that those pairs will not be selected later on.
  • To each entry of T, add K * np.random.uniform(0, 1), where K decreases from 1000 to as small as needed (e.g. 0.05) as the model trains for more epochs. If we didn’t use K (i.e. set K = 0), we would be mining the deterministically hardest pairs to train the model, which is not ideal.
  • Solve the linear assignment problem on T to find, for every row i, a corresponding column j such that the total cost (the sum of all selected values) is smallest. We then use the selected (i, j) pairs for the next epoch of training.
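The steps above can be sketched roughly as follows. This is an illustration, not our training code: I am assuming SciPy's `linear_sum_assignment` as the assignment solver, and the function name `mine_negative_pairs` and the tiny 4-image example are made up.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def mine_negative_pairs(T, labels, K, rng):
    """Pick one hard unmatched partner j for every image i.

    T      : (n, n) matrix of predicted match probabilities
    labels : n class labels, used to mask out matching pairs
    K      : noise scale; large K is close to random sampling, K = 0 is purely hardest-first
    """
    labels = np.asarray(labels)
    cost = -np.asarray(T, dtype=float)                  # high similarity -> low cost
    same = labels[:, None] == labels[None, :]
    cost[same] = 10000.0                                # forbid matching pairs (and i == j)
    cost += K * rng.uniform(0.0, 1.0, size=cost.shape)  # randomize the selection
    rows, cols = linear_sum_assignment(cost)            # minimize the total cost
    return list(zip(rows.tolist(), cols.tolist()))

T = np.array([
    [1.0, 0.9, 0.9, 0.1],
    [0.9, 1.0, 0.2, 0.8],
    [0.9, 0.2, 1.0, 0.7],
    [0.1, 0.8, 0.7, 1.0],
])
labels = ["w1", "w1", "w2", "w2"]
rng = np.random.default_rng(0)
print(mine_negative_pairs(T, labels, K=0.0, rng=rng))
# [(0, 2), (1, 3), (2, 0), (3, 1)]
```

With K = 0 each image is paired with the different-class image the model confuses it with most, subject to every column being used once. Raising K drowns the scores in noise, which is what makes the early epochs behave like random sampling.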

For matching pairs, we do the simplest random sampling and draw derangements. This makes sure that no image is paired with itself and that every image is used exactly once when forming pairs.
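A minimal way to draw such derangements, assuming a dict that maps each class to its list of distinct image names (the dict layout and the name `matching_pairs` are assumptions for this sketch), is to reshuffle within each class until no image lands on its own position:

```python
import random

def matching_pairs(class_to_images, rng):
    """Pair every image with a different image of the same class."""
    pairs = []
    for images in class_to_images.values():
        if len(images) < 2:
            continue  # a single image cannot form a matching pair
        partner = list(images)
        # Reshuffle until no image is paired with itself (a derangement).
        while any(a == b for a, b in zip(images, partner)):
            rng.shuffle(partner)
        pairs.extend(zip(images, partner))
    return pairs

classes = {"w1": ["a", "b", "c"], "w2": ["d", "e"], "w3": ["f"]}
pairs = matching_pairs(classes, random.Random(0))
print(len(pairs))  # 5: one pair per image in classes with at least two images
```

Each image then appears exactly once on the left side and exactly once on the right side of the matched pairs.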

Eventually, with the linear assignment calculation as well as derangements, we get 50% matching samples and 50% unmatching samples for each epoch of training.

4.3 Cross validation and K-fold training

Our training procedure follows a 4-fold stratified cross validation approach: with all the training data we have, we first do a random stratified split, i.e. we randomly split the data into 4 folds class by class.

However, we only use classes that contain at least 2 images for training, because a single-image class cannot form matching pairs. Classes that have exactly two images always stay in the training fold and are never split up, so that all training classes have at least two images. Finally, classes that have 3, 4, or more images are split into the four folds as evenly as possible.

To each validation fold, we add in around 25% of images randomly selected from the ‘new_whale’ image pool, to make it similar to the percentage in the LB set. The training folds don’t contain any ‘new_whale’ images, in keeping with our training logic.
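One way to implement these splitting rules is sketched here. The details are my assumptions for illustration: a dict from class to image names, a marker value `TRAIN_ONLY` for images that never go into validation, and a random starting fold per class. Adding the ‘new_whale’ images to the validation folds is left out.

```python
import random

TRAIN_ONLY = -1  # hypothetical marker: image is used for training in every fold

def assign_folds(class_to_images, n_folds=4, seed=0):
    """Return {image: fold index}, following the class-size rules described above."""
    rng = random.Random(seed)
    fold_of = {}
    for images in class_to_images.values():
        if len(images) < 2:
            continue                       # single-image classes are not used
        if len(images) == 2:
            for img in images:
                fold_of[img] = TRAIN_ONLY  # never split a two-image class
            continue
        shuffled = list(images)
        rng.shuffle(shuffled)
        offset = rng.randrange(n_folds)    # avoid always filling fold 0 first
        for k, img in enumerate(shuffled):
            fold_of[img] = (offset + k) % n_folds
    return fold_of

classes = {
    "big": [f"b{i}" for i in range(8)],
    "three": ["t0", "t1", "t2"],
    "two": ["p0", "p1"],
    "one": ["s0"],
}
folds = assign_folds(classes)
```

With this scheme a three-image class contributes at most one image to any validation fold, so at least two of its images always remain available for training.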

During training, the validation fold serves as a very good indicator for early stopping. We monitor the MAP5 score on validation roughly every 20 epochs, and we also select the best model weight snapshots based on validation MAP5 scores.
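For readers who have not met MAP5 before: with a single true label per image, it is the mean of 1/rank of the correct label within the top 5 predictions, or 0 if the label is missing from the top 5. A small reference implementation (the function name `map5` and the example labels are mine):

```python
def map5(predictions, truths):
    """predictions: ranked label lists (best first); truths: the true label per image."""
    total = 0.0
    for ranked, truth in zip(predictions, truths):
        for rank, label in enumerate(ranked[:5], start=1):
            if label == truth:
                total += 1.0 / rank
                break
    return total / len(truths)

preds = [["w1", "w2", "new_whale"], ["w3", "w1", "w2"], ["w4", "w5", "w6"]]
truth = ["w1", "w1", "w9"]
print(map5(preds, truth))  # (1 + 1/2 + 0) / 3 = 0.5
```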

In fact, some models were discarded after only a few rounds of training if their validation scores turned out to be quite bad. This made it more efficient to test many models within a short period.

During inference, we ran predictions with each of the 4 fold models on the same test data and averaged the four test-vs-train score matrices to get our final score matrix, from which we then computed the final 5-prediction submission file to submit to the LB.
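In code, the averaging and ranking step looks roughly like this. It is a simplified sketch with hypothetical names, and it leaves out how ‘new_whale’ gets inserted into the ranking.

```python
import numpy as np

def top5_from_folds(fold_scores, train_labels):
    """fold_scores: list of (n_test, n_train) score matrices, one per fold model."""
    mean_scores = np.mean(fold_scores, axis=0)
    results = []
    for row in mean_scores:
        ranked, seen = [], set()
        for j in np.argsort(-row):   # best-scoring training images first
            label = train_labels[j]
            if label not in seen:    # keep each class only once
                seen.add(label)
                ranked.append(label)
            if len(ranked) == 5:
                break
        results.append(ranked)
    return results

fold_a = np.array([[0.2, 0.9, 0.4]])
fold_b = np.array([[0.4, 0.7, 0.8]])
print(top5_from_folds([fold_a, fold_b], ["w1", "w1", "w2"]))
# [['w1', 'w2']]  (mean scores are 0.3, 0.8, 0.6)
```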

The k-fold approach is a great way to 1) stabilize the test-vs-train score matrices, 2) serve as an ensemble to improve performance, 3) stop early to prevent overfitting, and 4) quickly test and compare models.

In addition, k-fold generates additional validation predictions, which are very useful for comparing models locally as well as for the final stacking strategy.

Finally, we noticed that our local scores (LS) have a very consistent gap with LB scores: our LS is always 0.003~0.004 higher than the public LB score.

4.4 Bootstrapped data / pseudo labeling

As Martin mentioned in this post about the bootstrapping method for improving model performance, we adopted a similar approach: we took our best models’ predictions on the test set (the test-vs-train score matrix) and selected the confident predictions (e.g. probability of 0.99999+) to convert into new training data.
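A minimal sketch of that selection step, assuming a test-vs-train score matrix and the list of training labels are at hand (the function name `select_pseudo_labels` and the toy numbers are placeholders):

```python
import numpy as np

def select_pseudo_labels(scores, train_labels, threshold=0.99999):
    """Return {test index: label} for test images with a very confident best match.

    scores: (n_test, n_train) matrix of predicted match probabilities
    """
    best = scores.argmax(axis=1)
    best_score = scores[np.arange(scores.shape[0]), best]
    confident = np.flatnonzero(best_score >= threshold)
    return {int(i): train_labels[best[i]] for i in confident}

scores = np.array([
    [0.999995, 0.10],   # confident match with training image 0
    [0.60, 0.70],       # not confident enough, so it is skipped
])
print(select_pseudo_labels(scores, ["w1", "w2"]))  # {0: 'w1'}
```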

One thing to highlight is that we changed our pseudo-labeled dataset every time we got an improved version of the model (e.g. LB 0.938 -> LB 0.950 -> LB 0.965, etc.). And every time we regenerated the pseudo-labeled test data from a new model, we had to re-create the k-fold split. This is simply because the dataset has changed, so the split needs to reflect that. There was a time when we forgot to do this: we ended up using messed-up k-folds, and our local score strongly overfitted! We had to discard the models and redo the training from the beginning.

However, it is hard for us to say how much pseudo labeling actually helped, or whether it helped at all. This is because our model’s accuracy had already reached a high level (0.96+), so all the misclassified cases were difficult even for humans, whereas the data we select for pseudo labeling are those with a very high probability of being a match, which means they are very simple cases.

But pseudo labeling did help in that it generates more training pairs. For example, a single-image class that was not being used for training could form a valid training pair if a pseudo-labeled test image happens to fall into that same class!

5. Ensembling and stacking

Our final solution consists of around 15 ensembled predictions (i.e. test-vs-train score matrices). Some of these come from different models, such as ImageNet pretrained models like DenseNet121, InceptionV3 and SE-ResNet, or custom ConvNet architectures similar to the one designed by Martin in his post.

We also included in our ensemble predictions from the same model architectures trained with different strategies, such as different input image sizes, different augmentation strategies, and different initializations (i.e. some models were pre-trained on the Playground dataset before training on the new dataset). Ensembling all the models simply means averaging the score matrices from each model and using the averaged final score matrix to generate the final submission file.

In addition, we explored stacking: for each image pair, we use all 15 models’ predictions to form a new 15-dimensional data point. We end up with roughly train_num * train_num such data points. Since we know the ground truth of matching (1) or unmatching (0) for each data point, we feed those points into a stacker model such as xgboost or a neural net, and train it as a second-level binary classifier.

Stacking workflow – dimensions are for illustrative purpose
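To illustrate how the second-level training set is assembled, here is a small NumPy sketch. It only builds the features and targets; fitting the stacker itself is omitted. The function name is made up, and I keep each unordered pair once (upper triangle), which is a simplification of the roughly train_num * train_num points mentioned above.

```python
import numpy as np

def build_stacking_set(score_matrices, labels):
    """score_matrices: list of M (n, n) score matrices; labels: n class labels."""
    labels = np.asarray(labels)
    stacked = np.stack(score_matrices, axis=-1)  # shape (n, n, M)
    i, j = np.triu_indices(len(labels), k=1)     # every unordered pair once, no self-pairs
    X = stacked[i, j]                            # shape (n_pairs, M): one row per image pair
    y = (labels[i] == labels[j]).astype(int)     # 1 = matching, 0 = unmatching
    return X, y

m1 = np.array([[1.0, 0.9, 0.1], [0.9, 1.0, 0.2], [0.1, 0.2, 1.0]])
m2 = np.array([[1.0, 0.8, 0.3], [0.8, 1.0, 0.4], [0.3, 0.4, 1.0]])
X, y = build_stacking_set([m1, m2], ["w1", "w1", "w2"])
print(X.shape, y.tolist())  # (3, 2) [1, 0, 0]
```

With 15 models, `X` would have 15 columns, and `X`, `y` can then be passed to any binary classifier. The score matrices used here should come from held-out (validation fold) predictions, otherwise the stacker learns from scores the models have already overfitted.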

6. Further discussion

To investigate which areas the siamese network is actually ‘looking at’ when presented with input images, we used some CNN visualization tools:

Example 1 – input image with activation areas
Example 2 – input image with activation areas
Example 3 – input image with activation areas (minor mis-activations on object far away at the top)
Example 4 – input image with activation areas
Test data matching pairs identified by our model (first row: input image from test set with its activation areas, second row: the known image from training set that was identified by model as a match)

It is amazing to see that the model actually picks out the correct features from the input images. In some cases (like example 3 above) it can mis-identify some features, such as a distant object or a flying bird, but overall the extracted features are correct enough to justify a valid prediction for most images!

We have shown here a successful use case of siamese neural networks for one-shot / metric learning, and I have no doubt that such applications will have a far wider impact in other cases where datasets are inherently long-tailed and sparse, and where the objects to be identified are fine-grained, as with whale flukes.

7. Credits


