Keras 3/TF vs. PyTorch – small model performance tests on a Nvidia 4060 TI

There are many PROs and CONs regarding the choice of a Machine Learning [ML] framework for private studies on a Linux workstation. Two widely used frameworks are PyTorch and the Keras/Tensorflow combination. One important aspect of productive work with ML models certainly is performance. And as I personally do not have TPUs or other advanced chips available, but just a consumer Nvidia 4060 TI graphics card, performance and optimal GPU usage are of major interest – even for the training of relatively small models.

With this post I just want to point out that the question of performance advantages of one framework over the other on a CUDA-controlled graphics card cannot be answered in a unique way. Even for small neural network [NN] models the performance may depend on a variety of relevant settings, on jit-/xla-compilation and on the chosen precision level of your training or inference runs.

Performance impressions during the training of a small CNN model

I have been used to the Keras/Tensorflow combination for almost 5 years now. Due to current developments on the LLM side, I have changed my personal policy a bit and also started to use PyTorch. To get used to PyTorch, I first tried to reproduce results of small, easy-to-create CNN models – with standard float32 precision.

To train a model set up with torch.nn layers, I used a standard torch training loop. For Keras I used the standard model.compile() and model.fit() functionalities. Data transfer to the GPU was handled by framework-specific “datasets” used for pipeline building.
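
For orientation, here is a minimal sketch of the two training styles. It is only a sketch – the model, data loader and hyperparameters are placeholders and do not reproduce my actual test setup:

    import torch
    import torch.nn as nn

    # PyTorch: a standard, explicit training loop (placeholder model and loader)
    device = "cuda"
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    def train_one_epoch(loader):
        for x, y in loader:                    # batches delivered by a DataLoader
            x, y = x.to(device), y.to(device)  # GPU transfer, unless preloaded
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()

    # Keras 3 / TF2: compile() and fit() encapsulate the same loop internally
    # keras_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
    # keras_model.fit(train_ds, epochs=20)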

With already prepared tensors, I got the impression that PyTorch was significantly faster – without jit-compilation even substantially faster. Then I changed to “mixed precision” – and, somewhat confusingly, my impression changed. Now, Keras3 with the Tensorflow 2 [TF2] backend seemed to be much faster. When looking closer at the calculations and the respective relevant settings, I also found a multitude of adjustable parameters specific to each framework.
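
For reference, switching to mixed precision typically looks like this in the two frameworks. Again only a sketch, assuming a CUDA device and using dummy data to keep it self-contained – not my actual test code:

    import torch
    import torch.nn as nn

    # PyTorch: autocast for the forward pass, a GradScaler for the backward pass
    device = "cuda"
    model = nn.Linear(128, 10).to(device)                 # dummy model
    x = torch.randn(256, 128, device=device)              # dummy batch
    y = torch.randint(0, 10, (256,), device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()

    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Keras 3 / TF2: a global dtype policy, set before the model is built
    # import keras
    # keras.mixed_precision.set_global_policy("mixed_float16")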

Performance tests for a small CNN model

So, regarding small neural network models and a Nvidia 4060 TI, performance appeared to depend … I wanted to find out on what. Meanwhile, I have done some elementary, but systematic tests and published respective results in my sister-blog “machine-learning@anracom.com”. The model I used for my reproducible tests had only 536,010 parameters. So, with small batch sizes of e.g. MNIST data you would not challenge the 4060 TI. However, for batch sizes around 128 or 256 and more than 20 epochs one can generate a GPU workload suitable for performance tests.

For details regarding PyTorch see e.g. the post

PyTorch / datasets / dataloader / data transfer to GPU – III – prepared tensor datasets and preloading to GPU.

It gives you an overview of how you can tune PyTorch performance by using tensor datasets of prepared tensors (pre-loaded to the GPU) – together with adapting the batch size of your data.

A performance comparison between PyTorch and Keras3/TF2 and a discussion of a variety of performance relevant parameters and framework-specific settings can be found here:

Performance of PyTorch vs. Keras 3 with tensorflow/torch backends for a small NN-model on a Nvidia 4060 TI – I – Torch vs. Keras3/TF2 and relevant parameters

In the latter post, I additionally studied the impact of parameters, framework-specific datasets/dataloaders, jit-compilation and mixed precision on performance – both for PyTorch and Keras3/TF2.

The posts named above also give you information on the versions of the relevant software, including CUDA/cuDNN.

Summary of performance results and major dependencies

My impressions regarding performance differences between PyTorch and Keras3/TF2 were verified by measurements during my test runs. Below, I just summarize the main results. For more details see the posts named above.

  • PyTorch: For optimal performance with PyTorch, use tensor datasets built upon prepared tensors whenever possible. For such tensor datasets, set the parameter “num_workers” of the respective dataloader to num_workers=0 (a code sketch follows after this list).
  • PyTorch with float32 precision: PyTorch provided excellent performance for float32 precision and tensor datasets on my CUDA device (i.e. on the Nvidia graphics card) – even without jit-compilation. The turnaround times of training runs with PyTorch were significantly better than those with Keras3/TF2 (with/without jit) for float32 – by factors between 1.5 and 1.9. The GPU load created by PyTorch with optimal settings went up to 100% – which has PROs and CONs.
  • PyTorch: jit-compilation and “mixed precision” (for the forward pass) together improved performance by around 20%, which is less than what I found for Keras3/TF2. For jit-compilation the mode parameter should be set to mode="reduce-overhead" or mode="max-autotune".
  • PyTorch: The present support of XLA-compilation is questionable for a 4060 TI – to put it mildly; see the link above for more information. First preliminary tests showed a decline in performance for torch-xla version 2.5. There seems to be much room for improvement.
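
To make the PyTorch-related points concrete, here is a hedged sketch of the kind of setup meant above: prepared tensors preloaded to the GPU, wrapped in a TensorDataset, a DataLoader with num_workers=0 and an optionally jit-compiled model. The data shapes and the model are placeholders only:

    import torch
    import torch.nn as nn
    from torch.utils.data import TensorDataset, DataLoader

    device = "cuda"

    # Preload the already prepared training tensors to the GPU once ...
    x_train = torch.randn(60000, 1, 28, 28, device=device)   # placeholder data
    y_train = torch.randint(0, 10, (60000,), device=device)

    # ... and wrap them in a TensorDataset. With tensors residing on the CUDA
    # device, the DataLoader must run without worker processes (num_workers=0).
    train_ds = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=0)

    model = nn.Sequential(                                    # placeholder CNN
        nn.Conv2d(1, 32, 3), nn.ReLU(), nn.Flatten(), nn.LazyLinear(10)
    ).to(device)

    # Optional jit-compilation; mode="max-autotune" is the more aggressive option
    model = torch.compile(model, mode="reduce-overhead")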

Now let us look at Keras3/TF2:

  • Keras3/TF2: The use of tf.data datasets helps, but the performance advantage over just providing tf tensors or even Numpy arrays to model.fit() is minor.
  • Keras3/TF2: However, adjusting the parameter “steps_per_execution” in the standard model.compile() statement is of major importance for optimal performance. jit-compilation should also always be activated by setting the respective parameter jit_compile=True in the model.compile() call; the compilation automatically uses XLA capabilities on CUDA devices (see the sketch after this list).
  • Keras3/TF2 with float32 precision: Strangely, the GPU load never went above 90% – not even with jit-compilation and autotune parameters for the tf.data datasets. This reflects the overall lower performance in comparison to PyTorch for float32 precision.
  • Keras3/TF2 with “mixed precision”: Keras3/TF2 excels for mixed precision combined with jit-compilation. In my tests with mixed precision, the performance of Keras3/TF2 was up to a factor of 1.6 better than that of PyTorch (the latter with optimal settings). Even the energy consumption of Keras3/TF2 for mixed precision was a bit lower than during the best mixed precision run of PyTorch.
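
The corresponding Keras3/TF2 settings, again only as a sketch with placeholder data and a placeholder model; the concrete value of steps_per_execution has to be tuned for your own case:

    import keras
    import tensorflow as tf
    import numpy as np

    # Optional: global mixed precision policy, to be set before building the model
    keras.mixed_precision.set_global_policy("mixed_float16")

    x_train = np.random.rand(60000, 28, 28, 1).astype("float32")   # placeholder data
    y_train = np.random.randint(0, 10, 60000)

    model = keras.Sequential([
        keras.layers.Input((28, 28, 1)),
        keras.layers.Conv2D(32, 3, activation="relu"),
        keras.layers.Flatten(),
        keras.layers.Dense(10, activation="softmax", dtype="float32"),  # keep output float32
    ])

    # A tf.data pipeline; its advantage over plain tensors/arrays was minor in my tests
    train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
                              .shuffle(10000).batch(128).prefetch(tf.data.AUTOTUNE)

    # The settings with major performance impact: steps_per_execution and jit_compile
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  steps_per_execution=64,   # value to be tuned
                  jit_compile=True)         # uses XLA on CUDA devices
    model.fit(train_ds, epochs=20)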

Both frameworks showed a significant reduction of GPU load and of the respective energy consumption for “mixed precision” – which had to be expected for a small neural network model.

The message of the results listed above is mixed. Note that results for bigger models on other GPUs or TPUs may look different.

My personal guess, so far, is that if your model needs float32 precision, you should try a run with PyTorch. If your model does not suffer from mixed precision, stick to the Keras3/TF2 combination for the time being.