A beginner’s tutorial on training a classifier with unlabeled data using semi-supervised learning (SSL)

Traditionally, training computer vision models like classifiers required labeled data. Each example in the training data needed to be a pair: an image, and a human-generated label describing the image.

Semi-supervised learning (SSL) lets a model learn from both labeled and unlabeled data. Unlabeled data consists solely of images, without any labels. Recently, new SSL techniques have delivered some of the most accurate models in computer vision on classic benchmarks like ImageNet.

SSL is great because there is usually a lot more unlabeled data than labeled, especially once you deploy a model into production. Also, SSL reduces the time, cost, and effort of labeling.

But how does a model learn from images without labels? The key insight is that images themselves have information. The magic of SSL is that it can extract information from unlabeled data by automatically clustering images that are similar based on their structure, and this clustering provides additional information for a model to learn from.
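To make this idea concrete, here is a toy sketch of pseudo-labeling, one simple flavor of SSL. It is an illustration under stated assumptions, not the method Masterful uses: 2-D points stand in for images, a nearest-centroid rule stands in for the model, and all names and values are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two well-separated 2-D clusters stand in for two image classes.
centers = np.array([[-2.0, 0.0], [2.0, 0.0]])
x_labeled = np.array([[-2.5, 0.5], [1.5, -0.5]])  # one labeled point per class
y_labeled = np.array([0, 1])
true_unlabeled = rng.integers(0, 2, size=200)     # hidden from the "model"
x_unlabeled = centers[true_unlabeled] + rng.normal(scale=0.5, size=(200, 2))

def nearest_centroid(points, centroids):
    """Return the index of the closest centroid for every point."""
    distances = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
    return distances.argmin(axis=1)

# Step 1: fit on the labeled data only (one centroid per class).
centroids = np.stack([x_labeled[y_labeled == c].mean(axis=0) for c in (0, 1)])

# Step 2: let the current model guess labels for the unlabeled points.
pseudo_labels = nearest_centroid(x_unlabeled, centroids)

# Step 3: refit on labeled plus pseudo-labeled data.
x_all = np.concatenate([x_labeled, x_unlabeled])
y_all = np.concatenate([y_labeled, pseudo_labels])
refined = np.stack([x_all[y_all == c].mean(axis=0) for c in (0, 1)])

print("centroids from labels only:", centroids)
print("centroids after pseudo-labeling:", refined)
```

Because the unlabeled points cluster around the true class centers, the refitted centroids should land closer to those centers than the ones estimated from two labeled points alone. The structure of the unlabeled data did the work.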

This tutorial uses several common Python libraries included in Google Colab, including matplotlib, numpy, and TensorFlow. If you need to install them, you can usually run !pip install --upgrade pip; pip install matplotlib numpy tensorflow within a Jupyter notebook or pip install --upgrade pip; pip install matplotlib numpy tensorflow from the command line (no exclamation point).

If you are using Google Colab, make sure to change the runtime type to GPU.

For this tutorial, let’s train a classifier on the CIFAR-10 dataset. This is a classic research dataset of natural images. Let’s load it up and take a look. We’ll see examples from the ten CIFAR-10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.

import matplotlib.pyplot as plt

def plot_images(images):
  """Simple utility to render images."""
  # Visualize the data.
  _, axarr = plt.subplots(5, 5, figsize=(15,15))

  for row in range(5):
    for col in range(5):
      image = images[row*5 + col]
      axarr[row, col].imshow(image)
      
import tensorflow as tf

NUM_CLASSES = 10
# Load the data using the Keras Datasets API. 
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

plot_images(x_test)

[Figure: Sample images]

Create the Model

In general, you’ll want to use a model architecture off the shelf. This saves you the effort of fiddling with model architecture design. The general rule of model sizing is to pick a model that’s big enough to handle your data, but not so big that it’s slow at inference time. For a very small dataset like CIFAR-10, we’ll use a very small model, MobileNet. For larger datasets with larger image sizes, the EfficientNet family is a good choice.

def get_model():    
    return tf.keras.applications.MobileNet(input_shape=(32,32,3), 
                                           weights=None, 
                                           classes=NUM_CLASSES, 
                                           classifier_activation=None)

model = get_model()
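Because the model is built with classifier_activation=None, it returns raw logits rather than probabilities. As a minimal sketch (the logit values below are made up for illustration), this is how logits can be turned into probabilities and a predicted class with a softmax:

```python
import numpy as np

def softmax(logits):
    """Convert raw logits to probabilities along the last axis."""
    # Subtracting the row maximum avoids overflow without changing the result.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Hypothetical logits for two images over three classes.
logits = np.array([[2.0, 1.0, 0.1],
                   [0.0, 0.0, 0.0]])
probs = softmax(logits)

print(probs.sum(axis=-1))     # each row sums to 1
print(probs.argmax(axis=-1))  # predicted class index per image
```

The predicted class is the same whether you take the argmax of the logits or of the probabilities, which is why accuracy metrics work directly on logits.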

Prepare the Data

Now, let’s prepare the data by converting the labels, which are integers from 0 to 9 representing the 10 classes of objects, into one-hot vectors like [1,0,0,0,0,0,0,0,0,0] and [0,0,0,0,0,0,0,0,0,1]. We’ll also update the image pixels to a range expected by the model architecture, namely, the range [-1, 1].

def normalize_data(x_train, y_train, x_test, y_test):
  """Utility to normalize the data into standard formats."""

  # Update the pixel range to [-1,1], which is expected by the model architecture.
  x_train = tf.keras.applications.mobilenet.preprocess_input(x_train)
  x_test = tf.keras.applications.mobilenet.preprocess_input(x_test)

  # Convert to one-hot labels. 
  y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
  y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)

  return x_train, y_train, x_test, y_test
  
x_train, y_train, x_test, y_test = \
  normalize_data(x_train, y_train, x_test, y_test)
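To see what these two transformations do, here is a small standalone sketch in plain NumPy. It is not the Keras implementation; it only reproduces the arithmetic on a few hand-picked values: one-hot encoding of integer labels, and the x / 127.5 - 1 scaling that maps pixel values from [0, 255] onto [-1, 1].

```python
import numpy as np

NUM_CLASSES = 10

# Integer class ids, as CIFAR-10 provides them.
labels = np.array([0, 9, 3])
# Row i of the identity matrix is the one-hot vector for class i.
one_hot = np.eye(NUM_CLASSES, dtype=np.float32)[labels]

# A few pixel intensities spanning the original [0, 255] range.
pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)
scaled = pixels / 127.5 - 1.0

print(one_hot)
print(scaled)
```

The first label becomes [1,0,0,0,0,0,0,0,0,0], the second becomes [0,0,0,0,0,0,0,0,0,1], and the darkest and brightest pixels land on -1 and 1.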

The CIFAR-10 training set includes 50,000 examples. Let’s use 5,000 of them as labeled images and 20,000 as unlabeled images, leaving the rest unused.

import numpy as np

def prepare_data(x_train, y_train, num_labeled_examples, num_unlabeled_examples):
    """Returns labeled and unlabeled datasets."""
    # Count examples, not array elements (x_train.size would count every pixel).
    num_examples = len(x_train)

    assert num_labeled_examples + num_unlabeled_examples <= num_examples

    # Generate some random indices. 
    indices = np.arange(num_examples)
    generator = np.random.default_rng(seed=0)
    generator.shuffle(indices)

    # Split the indices into two sets: one for labeled, one for unlabeled. 
    labeled_train_indices = indices[:num_labeled_examples]
    unlabeled_train_indices = indices[num_labeled_examples : num_labeled_examples + num_unlabeled_examples]

    x_labeled_train = x_train[labeled_train_indices]
    y_labeled_train = y_train[labeled_train_indices]

    x_unlabeled_train = x_train[unlabeled_train_indices]
    # Since this is unlabeled, we won't need a y_unlabeled_train. 

    return x_labeled_train, y_labeled_train, x_unlabeled_train

NUM_LABELED = 5000
NUM_UNLABELED = 20000

x_labeled_train, y_labeled_train, x_unlabeled_train = \
    prepare_data(x_train, 
                 y_train, 
                 num_labeled_examples=NUM_LABELED, 
                 num_unlabeled_examples=NUM_UNLABELED)

del x_train, y_train
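A quick sanity check worth doing on a split like this is confirming that no image ends up in both the labeled and the unlabeled set. The sketch below reproduces only the index logic on its own (split_indices is a hypothetical helper written for this check, not part of the tutorial code), so it runs without loading the dataset.

```python
import numpy as np

def split_indices(dataset_size, num_labeled, num_unlabeled, seed=0):
    """Shuffle indices once, then carve out two non-overlapping slices."""
    indices = np.arange(dataset_size)
    np.random.default_rng(seed=seed).shuffle(indices)
    labeled = indices[:num_labeled]
    unlabeled = indices[num_labeled:num_labeled + num_unlabeled]
    return labeled, unlabeled

labeled_idx, unlabeled_idx = split_indices(50_000, 5_000, 20_000)
overlap = np.intersect1d(labeled_idx, unlabeled_idx)

print(len(labeled_idx), len(unlabeled_idx), len(overlap))
```

Since both slices come from a single shuffled permutation, the overlap is empty by construction. Fixing the seed also makes the split reproducible from run to run.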

Baseline Training

To measure the performance improvements from SSL, let’s first measure the performance of the model with a standard training loop without SSL.

Let’s set up a standard training loop with some basic data augmentations. Data augmentation is a type of regularization, which fights overfitting and allows your model to generalize better to data it has never seen.
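The only augmentation enabled in this tutorial’s training code is a random horizontal flip. As a rough sketch of what that means (a standalone NumPy illustration with a hypothetical helper, not how Keras implements it internally), each image in a batch is mirrored left-to-right with some probability:

```python
import numpy as np

def random_horizontal_flip(batch, rng, p=0.5):
    """Mirror each image in an (N, H, W, C) batch left-to-right with probability p."""
    flip_mask = rng.random(len(batch)) < p
    flipped = batch.copy()
    # Reverse the width axis only for the selected images.
    flipped[flip_mask] = batch[flip_mask][:, :, ::-1, :]
    return flipped, flip_mask

rng = np.random.default_rng(0)
# Four tiny 2x3 single-channel "images" with distinct pixel values.
batch = np.arange(4 * 2 * 3, dtype=np.float32).reshape(4, 2, 3, 1)
augmented, flip_mask = random_horizontal_flip(batch, rng)

print("flipped images:", np.flatnonzero(flip_mask))
```

A mirrored frog is still a frog, so the label stays the same while the model sees a slightly different input each epoch. That is what makes augmentation act as a regularizer.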

The hyperparameter values below (learning rate, epochs, batch size, etc.) are a combination of common default values and manually tuned values.

The result is a model that’s about 44% accurate on the test set. (Remember to read the validation accuracy, not the training accuracy.) Our next task will be figuring out if we can improve our model’s accuracy using SSL.

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)

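# Set up Keras augmentation. 
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False,
    featurewise_std_normalization=False,
    horizontal_flip=True)

# fit() is only required when featurewise statistics or ZCA whitening are
# enabled, so it has no effect on this configuration.
datagen.fit(x_labeled_train)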

batch_size = 64
epochs = 30
model.fit(
    x = datagen.flow(x_labeled_train, y_labeled_train, batch_size=batch_size),
    shuffle=True,
    validation_data=(x_test, y_test),
    batch_size=batch_size,
    epochs=epochs,
)
baseline_metrics = model.evaluate(x=x_test, y=y_test, return_dict=True)
print('')
print(f"Baseline model accuracy: {baseline_metrics['categorical_accuracy']}")

Output:

Epoch 1/30
79/79 [==============================] - 4s 23ms/step - loss: 2.4214 - categorical_accuracy: 0.1578 - val_loss: 2.3047 - val_categorical_accuracy: 0.1000
Epoch 2/30
79/79 [==============================] - 1s 16ms/step - loss: 2.0831 - categorical_accuracy: 0.2196 - val_loss: 2.3063 - val_categorical_accuracy: 0.1000
Epoch 3/30
79/79 [==============================] - 1s 16ms/step - loss: 1.9363 - categorical_accuracy: 0.2852 - val_loss: 2.3323 - val_categorical_accuracy: 0.1000
Epoch 4/30
79/79 [==============================] - 1s 16ms/step - loss: 1.8324 - categorical_accuracy: 0.3174 - val_loss: 2.3496 - val_categorical_accuracy: 0.1000
Epoch 5/30
79/79 [==============================] - 1s 16ms/step - loss: 1.8155 - categorical_accuracy: 0.3438 - val_loss: 2.3339 - val_categorical_accuracy: 0.1000
Epoch 6/30
79/79 [==============================] - 1s 15ms/step - loss: 1.6477 - categorical_accuracy: 0.3886 - val_loss: 2.3606 - val_categorical_accuracy: 0.1000
Epoch 7/30
79/79 [==============================] - 1s 15ms/step - loss: 1.6120 - categorical_accuracy: 0.4100 - val_loss: 2.3585 - val_categorical_accuracy: 0.1000
Epoch 8/30
79/79 [==============================] - 1s 16ms/step - loss: 1.5884 - categorical_accuracy: 0.4220 - val_loss: 2.1796 - val_categorical_accuracy: 0.2519
Epoch 9/30
79/79 [==============================] - 1s 18ms/step - loss: 1.5477 - categorical_accuracy: 0.4310 - val_loss: 1.8913 - val_categorical_accuracy: 0.3145
Epoch 10/30
79/79 [==============================] - 1s 15ms/step - loss: 1.4328 - categorical_accuracy: 0.4746 - val_loss: 1.7082 - val_categorical_accuracy: 0.3696
Epoch 11/30
79/79 [==============================] - 1s 16ms/step - loss: 1.4328 - categorical_accuracy: 0.4796 - val_loss: 1.7679 - val_categorical_accuracy: 0.3811
Epoch 12/30
79/79 [==============================] - 2s 20ms/step - loss: 1.3962 - categorical_accuracy: 0.5020 - val_loss: 1.8994 - val_categorical_accuracy: 0.3690
Epoch 13/30
79/79 [==============================] - 1s 16ms/step - loss: 1.3271 - categorical_accuracy: 0.5156 - val_loss: 2.0416 - val_categorical_accuracy: 0.3688
Epoch 14/30
79/79 [==============================] - 1s 17ms/step - loss: 1.2711 - categorical_accuracy: 0.5374 - val_loss: 1.9231 - val_categorical_accuracy: 0.3848
Epoch 15/30
79/79 [==============================] - 1s 15ms/step - loss: 1.2312 - categorical_accuracy: 0.5624 - val_loss: 1.9006 - val_categorical_accuracy: 0.3961
Epoch 16/30
79/79 [==============================] - 1s 19ms/step - loss: 1.2048 - categorical_accuracy: 0.5720 - val_loss: 2.0102 - val_categorical_accuracy: 0.4102
Epoch 17/30
79/79 [==============================] - 1s 16ms/step - loss: 1.1365 - categorical_accuracy: 0.6000 - val_loss: 2.1400 - val_categorical_accuracy: 0.3672
Epoch 18/30
79/79 [==============================] - 1s 18ms/step - loss: 1.1992 - categorical_accuracy: 0.5840 - val_loss: 2.1206 - val_categorical_accuracy: 0.3933
Epoch 19/30
79/79 [==============================] - 2s 25ms/step - loss: 1.1438 - categorical_accuracy: 0.6012 - val_loss: 2.4035 - val_categorical_accuracy: 0.4014
Epoch 20/30
79/79 [==============================] - 2s 24ms/step - loss: 1.1211 - categorical_accuracy: 0.6018 - val_loss: 2.0224 - val_categorical_accuracy: 0.4010
Epoch 21/30
79/79 [==============================] - 2s 21ms/step - loss: 1.0425 - categorical_accuracy: 0.6358 - val_loss: 2.2100 - val_categorical_accuracy: 0.3911
Epoch 22/30
79/79 [==============================] - 1s 16ms/step - loss: 1.1177 - categorical_accuracy: 0.6116 - val_loss: 1.9892 - val_categorical_accuracy: 0.4285
Epoch 23/30
79/79 [==============================] - 1s 19ms/step - loss: 1.0236 - categorical_accuracy: 0.6412 - val_loss: 2.1216 - val_categorical_accuracy: 0.4211
Epoch 24/30
79/79 [==============================] - 1s 18ms/step - loss: 0.9487 - categorical_accuracy: 0.6714 - val_loss: 2.0135 - val_categorical_accuracy: 0.4307
Epoch 25/30
79/79 [==============================] - 1s 16ms/step - loss: 1.1877 - categorical_accuracy: 0.5876 - val_loss: 2.3732 - val_categorical_accuracy: 0.3923
Epoch 26/30
79/79 [==============================] - 2s 20ms/step - loss: 1.0639 - categorical_accuracy: 0.6288 - val_loss: 1.9291 - val_categorical_accuracy: 0.4291
Epoch 27/30
79/79 [==============================] - 2s 19ms/step - loss: 0.9243 - categorical_accuracy: 0.6882 - val_loss: 1.8552 - val_categorical_accuracy: 0.4343
Epoch 28/30
79/79 [==============================] - 1s 15ms/step - loss: 0.9784 - categorical_accuracy: 0.6656 - val_loss: 2.0175 - val_categorical_accuracy: 0.4386
Epoch 29/30
79/79 [==============================] - 1s 17ms/step - loss: 0.9316 - categorical_accuracy: 0.6800 - val_loss: 1.9916 - val_categorical_accuracy: 0.4305
Epoch 30/30
79/79 [==============================] - 1s 17ms/step - loss: 0.8816 - categorical_accuracy: 0.7054 - val_loss: 2.0281 - val_categorical_accuracy: 0.4366
313/313 [==============================] - 1s 3ms/step - loss: 2.0280 - categorical_accuracy: 0.4366

Baseline model accuracy: 0.436599999666214

Training with SSL

Now, let’s see if we can improve our model’s accuracy by adding unlabeled data to our training data. We will be using Masterful, a platform that implements SSL for computer vision models like our classifier.

Let’s install Masterful. In Google Colab, we can pip install from a notebook cell. We can also install it from the command line. For more details, see the Masterful installation guide.

!pip install --upgrade pip
!pip install masterful

import masterful

masterful = masterful.register()

Output:

Loaded Masterful version 0.4.1. This software is distributed free of
charge for personal projects and evaluation purposes.
See http://www.masterfulai.com/personal-and-evaluation-agreement for details.
Sign up in the next 45 days at https://www.masterfulai.com/get-it-now
to continue using Masterful.

Set Up Masterful

Now, let’s set up some configuration parameters of Masterful.

# Start fresh with a new model
tf.keras.backend.clear_session()
model = get_model()

# Tell Masterful that your model is performing a classification task
# with 10 labels and that the image pixel range is 
# [-1,1]. Also, the model outputs logits rather than a softmax activation.
model_params = masterful.architecture.learn_architecture_params(
    model=model,
    task=masterful.enums.Task.CLASSIFICATION,
    input_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    prediction_logits=True,
)

# Tell Masterful that your labeled training data is using one-hot labels. 
labeled_training_data_params = masterful.data.learn_data_params(
    dataset=(x_labeled_train, y_labeled_train),
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=False,
)

unlabeled_training_data_params = masterful.data.learn_data_params(
    dataset=(x_unlabeled_train,),
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=None,
)

# Tell Masterful that your test/validation data is using one-hot labels. 
test_data_params = masterful.data.learn_data_params(
    dataset=(x_test, y_test),
    task=masterful.enums.Task.CLASSIFICATION,
    image_range=masterful.enums.ImageRange.NEG_ONE_POS_ONE,
    num_classes=NUM_CLASSES,
    sparse_labels=False,
)

# Let Masterful meta-learn ideal optimization hyperparameters like
# batch size, learning rate, optimizer, learning rate schedule, and epochs.
# This will speed up training. 
optimization_params = masterful.optimization.learn_optimization_params(
    model,
    model_params,
    (x_labeled_train, y_labeled_train),
    labeled_training_data_params,
)

# Let Masterful meta-learn ideal regularization hyperparameters. Regularization
# is an important ingredient of SSL. Meta-learning can
# take a while so we'll use a precached set of parameters.
# regularization_params = \
#   masterful.regularization.learn_regularization_params(model, 
#                                                        model_params, 
#                                                        optimization_params, 
#                                                        (x_labeled_train, y_labeled_train),
#                                                        labeled_training_data_params)

regularization_params = masterful.regularization.parameters.CIFAR10_SMALL

# Let Masterful meta-learn ideal SSL hyperparameters. 
ssl_params = masterful.ssl.learn_ssl_params(
    (x_labeled_train, y_labeled_train),
    labeled_training_data_params,
    unlabeled_datasets=[((x_unlabeled_train,), unlabeled_training_data_params)],
)

Output:

MASTERFUL: Learning optimal batch size.
MASTERFUL: Learning optimal initial learning rate for batch size 256.

Train!

Now, we are ready to train using SSL techniques! We will call masterful.training.train, which is the entry point to Masterful’s training engine.

training_report = masterful.training.train(
    model,
    model_params,
    optimization_params,
    regularization_params,
    ssl_params,
    (x_labeled_train, y_labeled_train),
    labeled_training_data_params,
    (x_test, y_test),
    test_data_params,
    unlabeled_datasets=[((x_unlabeled_train,), unlabeled_training_data_params)],
)

Output:

MASTERFUL: Training model with semi-supervised learning enabled.
MASTERFUL: Performing basic dataset analysis.
MASTERFUL: Training model with:
MASTERFUL: 	5000 labeled examples.
MASTERFUL: 	10000 validation examples.
MASTERFUL: 	0 synthetic examples.
MASTERFUL: 	20000 unlabeled examples.
MASTERFUL: Training model with learned parameters partridge-boiled-cap in two phases.
MASTERFUL: The first phase is supervised training with the learned parameters.
MASTERFUL: The second phase is semi-supervised training to boost performance.
MASTERFUL: Warming up model for supervised training.
MASTERFUL: 	Warming up batch norm statistics (this could take a few minutes).
MASTERFUL: 	Warming up training for 500 steps.
100%|██████████| 500/500 [00:47<00:00, 10.59steps/s]
MASTERFUL: 	Validating batch norm statistics after warmup for stability (this could take a few minutes).
MASTERFUL: Starting Phase 1: Supervised training until the validation loss stabilizes...
Supervised Training: 100%|██████████| 6300/6300 [02:33<00:00, 41.13steps/s]
MASTERFUL: Starting Phase 2: Semi-supervised training until the validation loss stabilizes...
MASTERFUL: Warming up model for semi-supervised training.
MASTERFUL: 	Warming up batch norm statistics (this could take a few minutes).
MASTERFUL: 	Warming up training for 500 steps.
100%|██████████| 500/500 [00:23<00:00, 20.85steps/s]
MASTERFUL: 	Validating batch norm statistics after warmup for stability (this could take a few minutes).
Semi-Supervised Training: 100%|██████████| 11868/11868 [08:06<00:00, 24.39steps/s]

Analyzing the Results

The model you passed into masterful.training.train is now trained and updated in place, so you are able to evaluate it just like any other trained Keras model.

masterful_metrics = model.evaluate(
    x_test, y_test, return_dict=True, verbose=0
)
print(f"Baseline model accuracy: {baseline_metrics['categorical_accuracy']}")
print(f"Masterful model accuracy: {masterful_metrics['categorical_accuracy']}")

Output:

Baseline model accuracy: 0.436599999666214
Masterful model accuracy: 0.558899998664856
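To put these two numbers in perspective, the short calculation below expresses the improvement three ways. The accuracies are copied from the output above (rounded to four decimals); the variable names are just for this illustration.

```python
baseline_accuracy = 0.4366  # baseline run, from the output above
ssl_accuracy = 0.5589       # Masterful SSL run, from the output above

# Difference in accuracy, in absolute terms.
absolute_gain = ssl_accuracy - baseline_accuracy
# Improvement relative to where the baseline started.
relative_gain = absolute_gain / baseline_accuracy
# Share of the baseline's errors that the SSL model no longer makes.
error_reduction = absolute_gain / (1.0 - baseline_accuracy)

print(f"Absolute gain: {absolute_gain * 100:.1f} percentage points")
print(f"Relative gain: {relative_gain:.1%}")
print(f"Error-rate reduction: {error_reduction:.1%}")
```

That works out to roughly 12 percentage points of accuracy, or about a 28% relative improvement over the baseline, from a single run of each method.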

Visualizing the Results

As you can see, the accuracy increased from around 0.44 to 0.56. Of course, a more rigorous study would attempt to ablate other differences between the baseline training and training using SSL via the Masterful platform, as well as repeating the runs several times and generating error bars and p-values. For now, let’s plot this as a chart to help explain our results.

data = (baseline_metrics['categorical_accuracy'], masterful_metrics['categorical_accuracy'])
fig, ax = plt.subplots(1, 1)

ax.bar(range(2), data, color=('gray', 'red'))

plt.xlabel("Training Method")
plt.ylabel("Accuracy")

plt.xticks((0,1), ("baseline", "SSL with Masterful"))

plt.show()

[Figure: Bar chart of accuracy by training method, baseline vs. SSL with Masterful]

Conclusion

Congrats! You’ve just successfully employed SSL, one of the most advanced training methods available, to improve your model’s accuracy in a simple tutorial. Along the way, you avoided the cost and effort of labeling 20,000 additional images.

SSL doesn’t just work for classification – various flavors work for just about any computer vision task. To go deeper into the subject and see SSL in action for Object Detection, check out additional tutorials here.