Saving and Loading Models Using TensorFlow 2.0+

How To Save And Load Your Deep Learning Models In TensorFlow 2

In this article, we discuss saving and loading models using TensorFlow 2.0+. It is a beginner-to-intermediate level article for people who have just started using TensorFlow for their deep learning projects.

Why do you need to save a model?

A very common beginner mistake in deep learning is not saving your models.

Saving a deep learning model both during and after training is good practice. It saves you time and improves the reproducibility of your results. Here are a few more reasons to save a model:

  • Training modern deep learning models with millions of parameters on huge datasets is expensive in both computation and time. Moreover, separate training runs can produce different results and accuracy. So it is always a good idea to present your results from a saved model rather than training on the spot.
  • Saving different versions of the same model lets you inspect and compare them to understand how the model behaves.
  • A saved model can be used from other languages and platforms that support TensorFlow, e.g. TensorFlow Lite and TensorFlow.js, without rewriting your model code.

TensorFlow offers several ways to save a model. We discuss each of them in the next few sections.

How to save a model during training?

Sometimes it is important to save model weights during training. If your results show an anomaly after a certain epoch, checkpointing makes it easier to inspect previous states of the model or even restore them.

Keras models in TensorFlow are trained using the Model.fit() function. We define a model checkpoint callback using tf.keras.callbacks.ModelCheckpoint() and pass it to Model.fit() to tell Keras to save the model weights at regular intervals during training.

Callbacks may sound difficult, but they are easy to use. Here is an example.

# This is the initialization block of code
# Not important for understanding the saving
# But to execute the next cells containing the code
# for saving and loading

import os

import tensorflow as tf
from tensorflow import keras

# We define a dummy sequential model.
# This function to create a model will be used throughout the article

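def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    # NOTE: the rest of the model was lost from the original listing.
    # The output layer and compile settings below are an assumed
    # completion for 10-class MNIST classification.
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model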

# Create a basic model instance
model = create_model()

# Get the dataset

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
# Create a new model using the function
model = create_model()

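# Specify the checkpoint file
# We use str.format() syntax to name files according to the epoch
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"

# Get the directory of the checkpoints
checkpoint_dir = os.path.dirname(checkpoint_path)

# Define the batch size
batch_size = 32

# save_freq counts batches, so convert "every 5 epochs" into batches
steps_per_epoch = (len(train_images) + batch_size - 1) // batch_size

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=1,
    save_freq=5 * steps_per_epoch)

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the checkpoint callback
# (epochs=50 is an assumed value; the original call was lost)
model.fit(train_images,
          train_labels,
          epochs=50,
          batch_size=batch_size,
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)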

Loading from a checkpoint

To restore a checkpoint that you created, use the model.load_weights() function.

Here is the syntax and an example for loading the weights.

# Syntax

model.load_weights("<path to the checkpoint file (*.ckpt)>")

# Example 

# Finds the latest checkpoint
latest = tf.train.latest_checkpoint(checkpoint_dir)

# Create a new model
model = create_model()

# Load the weights of the latest checkpoint
model.load_weights(latest)
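The checkpoint round trip can also be tried end to end without downloading MNIST. The following is a minimal sketch, not the article's exact code. It assumes random stand-in data, a hypothetical helper build_small_model(), and checkpoint files with the .weights.h5 suffix, which newer Keras releases appear to require for weights-only files (the .ckpt paths above assume the TensorFlow checkpoint format instead). Because tf.train.latest_checkpoint() looks for TensorFlow-format checkpoints, the sketch swaps in a plain sort over the file names to find the newest file.

```python
import glob
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: 64 random "images" with 784 features, 10 classes.
rng = np.random.default_rng(0)
features = rng.random((64, 784)).astype("float32")
labels = rng.integers(0, 10, size=(64,))


def build_small_model():
    # Hypothetical helper for this sketch; a smaller cousin of create_model().
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(784,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    return model


# One weights file per epoch, written into a temporary directory.
checkpoint_dir = tempfile.mkdtemp()
checkpoint_path = os.path.join(checkpoint_dir, "cp-{epoch:04d}.weights.h5")

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,  # store the weights only, not the full model
)

trained_model = build_small_model()
trained_model.fit(features, labels, epochs=3, batch_size=32,
                  callbacks=[cp_callback], verbose=0)

# Zero-padded epoch numbers sort correctly as plain strings.
checkpoint_files = sorted(
    glob.glob(os.path.join(checkpoint_dir, "cp-*.weights.h5"))
)
latest_file = checkpoint_files[-1]

# Restore the newest checkpoint into a freshly built model.
restored_model = build_small_model()
restored_model.load_weights(latest_file)

same_predictions = np.allclose(
    trained_model.predict(features, verbose=0),
    restored_model.predict(features, verbose=0),
)
print(len(checkpoint_files), os.path.basename(latest_file), same_predictions)
```

Since the last checkpoint is written at the end of the final epoch, the restored model should reproduce the trained model's predictions. Loading an earlier file from checkpoint_files instead gives you the model as it was after that epoch.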

Save the weights of a trained model

A model can also be saved after training. The process is much simpler than checkpointing during training.

To save the weights after a model is trained, we use the Model.save_weights() function. Here is an example:

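# Save the weights
# ('./checkpoints/my_checkpoint' is a placeholder path; the original was lost)
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')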

Load the weights of the trained model

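To load the weights into a model, we can use Model.load_weights(), just as we did for checkpoint weights. In fact, the weights are stored as a checkpoint file.

# Restore the weights
# (pass the same path that was given to save_weights(); this one is a placeholder)
model.load_weights('./checkpoints/my_checkpoint')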
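To see that save_weights() and load_weights() really transfer the parameters, it helps to compare the weight arrays directly. This is a small sketch under a few assumptions: build_small_model() is a hypothetical helper rather than the article's create_model(), the model is left untrained to keep the example fast, and the file uses the .weights.h5 suffix so that the same code should run on both older and newer Keras versions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_small_model():
    # Hypothetical helper for this sketch, not the article's create_model().
    return tf.keras.Sequential([
        tf.keras.Input(shape=(784,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(10),
    ])


# Save the weights of one model instance to a temporary file.
source_model = build_small_model()
weights_file = os.path.join(tempfile.mkdtemp(), "demo.weights.h5")
source_model.save_weights(weights_file)

# A fresh instance starts from its own random initialization...
fresh_model = build_small_model()
differs_before = not np.allclose(
    source_model.get_weights()[0], fresh_model.get_weights()[0]
)

# ...until the saved weights are loaded into it.
fresh_model.load_weights(weights_file)
matches_after = all(
    np.array_equal(saved, loaded)
    for saved, loaded in zip(source_model.get_weights(),
                             fresh_model.get_weights())
)
print(differs_before, matches_after)
```

The same comparison works after real training: call get_weights() on the trained model and on the model you restored, and check that every pair of arrays is equal.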

Saving and loading an entire model

In the previous section, we saw how to save the weights of a model. This approach has a drawback: the model must be defined before we can load the weights into it, and any structural difference between the model that produced the weights and the model you load them into can lead to errors.
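The structural-mismatch problem is easy to reproduce. In this sketch, build_model() is a hypothetical helper whose only variable part is the width of the hidden layer, and the weights are stored in an HDF5 weights file. The exact exception type and message depend on the TensorFlow/Keras version, so the example catches a broad Exception; that choice is an assumption made for portability, not a recommendation for production code.

```python
import os
import tempfile

import tensorflow as tf


def build_model(hidden_units):
    # Hypothetical helper: only the hidden layer width changes between models.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(784,)),
        tf.keras.layers.Dense(hidden_units, activation="relu"),
        tf.keras.layers.Dense(10),
    ])


# Save weights from a model with a 512-unit hidden layer.
weights_file = os.path.join(tempfile.mkdtemp(), "demo.weights.h5")
build_model(hidden_units=512).save_weights(weights_file)

# Try to load them into a model with a 256-unit hidden layer.
mismatched_model = build_model(hidden_units=256)
load_failed = False
try:
    mismatched_model.load_weights(weights_file)
except Exception as error:  # exact exception type varies by version
    load_failed = True
    print("Could not load weights:", type(error).__name__)
```

The saved kernel has shape (784, 512) while the new model expects (784, 256), so the load is rejected. Rebuilding the model with hidden_units=512 makes the same call succeed.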

Moreover, saving only the weights becomes awkward when we want to use a model across different platforms. For example, you may want to use a model trained in Python in your browser using TensorFlow.js.

In such cases, you need to save the whole model, i.e. the structure along with the weights. TensorFlow lets you do this with the model.save() function. Here is an example.

# Save the whole model in SavedModel format
model.save('my_model')

TensorFlow also lets you save the model in HDF5 format. To do so, give the filename the .h5 extension.

# Save the model in HDF5 format
# The .h5 extension indicates that the model is to be saved in HDF5 format
model.save('my_model.h5')

Note: HDF5 was the format originally used by Keras before Keras became part of TensorFlow. TensorFlow itself uses the SavedModel format, and it is advisable to go with the newer, recommended format.

You can load these saved models using tf.keras.models.load_model(). The function automatically detects whether the model was saved in SavedModel format or HDF5 format. Here is an example:

# For both HDF5 format and SavedModel format, use the appropriate path to the file

# SavedModel format
loaded_model = tf.keras.models.load_model('my_model')

# HDF5 format
loaded_model = tf.keras.models.load_model('my_model.h5')
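A quick way to convince yourself that a whole-model file carries both the architecture and the weights is to load it without any model-building code and compare predictions. The sketch below makes a few assumptions: it uses a small untrained model and random inputs to stay fast, a temporary file named demo_model.h5 (a hypothetical name), and compile=False on load, which skips restoring the optimizer and loss since they are not needed for inference.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# A small model standing in for a trained one.
original_model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(10),
])
original_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# Save architecture and weights together in a single HDF5 file.
model_file = os.path.join(tempfile.mkdtemp(), "demo_model.h5")
original_model.save(model_file)

# No create_model()-style function is called here: the file describes the layers.
loaded_model = tf.keras.models.load_model(model_file, compile=False)

# Both models should produce the same outputs for the same inputs.
sample = np.random.default_rng(0).random((4, 784)).astype("float32")
same_outputs = np.allclose(
    original_model.predict(sample, verbose=0),
    loaded_model.predict(sample, verbose=0),
)
print(len(loaded_model.layers), same_outputs)
```

If you want to continue training the loaded model, drop compile=False so that the saved optimizer and loss are restored as well, or call compile() on the loaded model yourself.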


This brings us to the end of the tutorial. Hopefully, you can now save and load models in your training process. Stay tuned to learn more about deep-learning frameworks like PyTorch, TensorFlow and JAX.