FlashOptim: Optimizers for Memory Efficient Training

2026-02-26

Machine Learning, Artificial Intelligence
AI summary

The authors created FlashOptim, a set of improvements that cuts the memory needed to train large neural networks by over half without degrading model quality. They achieved this by changing how weights and optimizer states are stored, using specially designed functions to reduce the error introduced when shrinking values to fewer bits. As a result, models with billions of parameters can be trained with far less accelerator memory, making large-scale training more accessible to researchers with limited resources. Their tests showed no loss in quality across various tasks, including finetuning large language models like Llama-3.1-8B.

mixed-precision training, neural networks, optimizer states, quantization error, memory optimization, AdamW optimizer, gradient quantization, model checkpointing, companding, Llama-3.1-8B
Authors
Jose Javier Gonzalez Ortiz, Abhay Gupta, Chris Renard, Davis Blalock
Abstract
Standard mixed-precision training of neural networks requires many bytes of accelerator memory for each model parameter. These bytes reflect not just the parameter itself, but also its gradient and one or more optimizer state variables. With each of these values typically requiring 4 bytes, training even a 7 billion parameter model can be impractical for researchers with less than 100GB of accelerator memory. We introduce FlashOptim, a suite of optimizations that reduces per-parameter memory by over 50% while preserving model quality and API compatibility. Our approach introduces two key techniques. First, we improve master weight splitting by finding and exploiting a tight bound on its quantization error. Second, we design companding functions that greatly reduce the error in 8-bit optimizer state quantization. Together with 16-bit gradients, these techniques reduce AdamW memory from 16 bytes to 7 bytes per parameter, or 5 bytes with gradient release. They also cut model checkpoint sizes by more than half. Experiments with FlashOptim applied to SGD, AdamW, and Lion show no measurable quality degradation on any task from a collection of standard vision and language benchmarks, including Llama-3.1-8B finetuning.
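To make the companding idea concrete: under standard mixed-precision AdamW, each parameter carries an fp32 master weight, an fp32 gradient, and two fp32 moment estimates, giving the 4 × 4 = 16 bytes per parameter cited above; quantizing each moment to 8 bits replaces 4 bytes with 1. The sketch below illustrates companded 8-bit quantization using a classic mu-law compander. This is a minimal illustration of the general idea only: the paper's actual companding functions and quantization format are not specified here, and `mu_law_quantize`/`mu_law_dequantize` are hypothetical helpers, not the FlashOptim API.

```python
import numpy as np

def mu_law_quantize(x, mu=255.0, bits=8):
    """Compand with mu-law, then round to a signed `bits`-bit integer.

    Companding spends more of the 8-bit budget on small magnitudes,
    which suits optimizer moments that cluster near zero.
    (Illustrative only; not FlashOptim's actual companding function.)
    """
    s = float(np.max(np.abs(x))) or 1.0            # per-tensor scale
    y = np.sign(x) * np.log1p(mu * np.abs(x) / s) / np.log1p(mu)  # in [-1, 1]
    levels = 2 ** (bits - 1) - 1                   # 127 for 8 bits
    q = np.round(y * levels).astype(np.int8)       # 1 byte per value vs. 4 for fp32
    return q, s

def mu_law_dequantize(q, s, mu=255.0, bits=8):
    """Invert the compander to recover an approximate fp value."""
    levels = 2 ** (bits - 1) - 1
    y = q.astype(np.float64) / levels
    return np.sign(y) * s * np.expm1(np.abs(y) * np.log1p(mu)) / mu

# Zero-centered values with small magnitudes, loosely resembling Adam moments.
rng = np.random.default_rng(0)
x = rng.normal(scale=1e-3, size=10_000)
q, s = mu_law_quantize(x)
err = np.abs(mu_law_dequantize(q, s) - x)
print(f"max abs reconstruction error: {err.max():.2e}")
```

A uniform 8-bit quantizer spreads its 255 levels evenly over the tensor's range, so values near zero, where most optimizer state mass lives, are resolved coarsely; the compander concentrates levels there instead, which is the error reduction the abstract refers to.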