MXNorm: Reusing MXFP block scales for efficient tensor normalisation

2026-03-13 · Machine Learning

Machine Learning · Artificial Intelligence · Neural and Evolutionary Computing
AI summary

The authors address a bottleneck in deep learning training: normalization still runs in higher precision than matrix multiplications and has not kept pace with matmul speedups. They introduce MXNorm, which reuses the per-block scale values already computed when casting tensors to the low-precision MXFP8 format to estimate the normalization statistic from far less data. Experiments on Llama 3 models show almost no loss in training accuracy, with normalization kernel speedups of up to 2.4x over RMSNorm. This translates into small but meaningful end-to-end gains in transformer-layer throughput, balancing precision and efficiency in deep learning computations.

Keywords
matrix multiplication, low-precision formats, RMSNorm, normalization, MXNorm, MXFP8, Llama 3, deep learning acceleration, torch.compile, reduction operations
Authors
Callum McLean, Luke Y. Prince, Alexandre Payot, Paul Balança, Carlo Luschi
Abstract
Matrix multiplication performance has long been the major bottleneck to scaling deep learning workloads, which has stimulated the design of new accelerators that use increasingly low-precision number formats. However, improvements in matrix multiplication performance have far outstripped improvements in performance on reductions and elementwise computations, which are still being performed in higher precision. In this work, we propose MXNorm, a drop-in replacement for RMSNorm that estimates the RMS using only the block scales calculated as part of the MXFP8 cast, enabling a 32x decrease in the size of the reduction needed for normalization. We validate our approximation method on pre-training of Llama 3 models of 125M, 1B and 8B parameters, finding minimal loss of training accuracy compared to a baseline using RMSNorm with MXFP8 matmuls. We also show practical kernel speedups using only torch.compile of up to 2.4x for MXNorm over RMSNorm, corresponding to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.
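To make the core idea concrete, here is a minimal NumPy sketch of estimating the RMS from MXFP8-style block scales rather than from the full tensor. The abstract does not spell out the exact estimator or scale-selection rule, so the E8M0 scale computation (block absmax rounded up to a power of two against the FP8 E4M3 maximum), the `calib` constant, and all function names below are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3
BLOCK = 32            # MX block size: one shared scale per 32 elements

def mx_block_scales(x):
    """Per-block power-of-two scales, as produced during an MXFP8 cast.

    Each block of 32 values shares one E8M0 scale chosen so the block's
    absmax fits in FP8 E4M3 range (a common MX scale-selection rule;
    the paper's exact rule may differ).
    """
    blocks = x.reshape(-1, BLOCK)
    absmax = np.abs(blocks).max(axis=1)
    # E8M0 stores only an exponent, so round the scale up to a power of two.
    exp = np.ceil(np.log2(absmax / FP8_E4M3_MAX))
    return np.exp2(exp)

def mxnorm_rms_estimate(x, calib=1.0):
    """Estimate RMS(x) by reducing over the 32x-smaller vector of scales.

    `calib` is a hypothetical calibration constant relating a block's
    absmax-derived scale to its RMS; the abstract does not give the
    estimator, so this is a sketch of the reduction-size saving only.
    """
    scales = mx_block_scales(x)  # length N/32 instead of N
    return calib * np.sqrt(np.mean((scales * FP8_E4M3_MAX) ** 2))

rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float32)

true_rms = np.sqrt(np.mean(x ** 2))
est_rms = mxnorm_rms_estimate(x)
# The raw estimate tracks each block's absmax, so it overshoots the true
# RMS by a roughly constant factor that `calib` could absorb.
print(f"true RMS {true_rms:.3f}, scale-based estimate {est_rms:.3f}")
```

The reduction input shrinks from N elements to N/32 block scales, which is the source of the claimed speedup; the estimate's systematic bias is constant for a fixed distribution and block size, which is why a single calibration factor could plausibly absorb it.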