Flatness and Generalization: Learning Multi-Index Models with Homogeneous Neural Networks
2026-06-03 • Machine Learning
AI summary
The authors study why some neural networks that fit training data perfectly still perform well on new data. Previous ideas said that 'flat' solutions (measured by a math tool called the Hessian) generalize better, but others showed that flatness can be artificially changed without affecting performance. In this paper, the authors find that if you focus on the flattest possible solutions, there is a genuine connection to good generalization, especially when data comes from certain structured models and noise is low. This helps explain when flatness really matters for understanding how neural networks learn.
Flatness, Generalization, Non-convex optimization, Neural networks, Hessian, Interpolators, Multi-index models, Single-index models, Symmetries, Population loss
Authors
Harsh Vardhan, Hossein Taheri, Arya Mazumdar
Abstract
A common heuristic used to explain the generalization of first-order gradient methods on non-convex neural networks is that "flat interpolators generalize well" (Hochreiter and Schmidhuber, 1994; Keskar et al., 2017), where flatness can be measured by the trace of the Hessian of the empirical loss. However, Dinh et al. (2017) showed that any interpolator can be made sharper or flatter by exploiting symmetries of the network that change flatness while keeping the population and empirical losses unchanged. This result makes the earlier heuristic statement vacuous. In this paper, we show that for learning an unknown multi-index model with $2$-layer non-convex homogeneous neural networks, there is a connection between flatness and generalization, despite the existence of symmetries. This connection pertains to the "flattest" interpolators, i.e., the interpolators that have orderwise minimum flatness among all interpolators. First, we show that there exists a natural class of non-generalizing interpolators whose flatness cannot be made closer to the flattest possible, even using symmetries. Second, we show that for data generated by a sum of single-index models, if the approximation error and label noise are low, any flattest interpolator achieves small population loss, i.e., the flattest interpolators always generalize. This establishes a direct link between flatness and generalization which applies to a large class of activations and realistic data distributions.
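The rescaling symmetry mentioned in the abstract can be seen on a toy problem. The sketch below is not the paper's construction. It assumes a hypothetical one-neuron network f(x) = a * relu(w * x) with squared loss on positive scalar inputs, and the names `loss`, `hessian_trace`, and `rescale` are illustrative. Because ReLU is positively homogeneous, the map (a, w) -> (a / c, c * w) leaves the function and the loss unchanged while changing the Hessian trace. In this toy case the trace can be inflated without bound but cannot go below the value reached at the balanced point |a| = |w|, which gives a concrete sense of what a "flattest" interpolator means.

```python
import math


def relu(z):
    return z if z > 0.0 else 0.0


def loss(a, w, xs, ys):
    """Half mean squared error of the toy network f(x) = a * relu(w * x)."""
    n = len(xs)
    return sum((a * relu(w * x) - y) ** 2 for x, y in zip(xs, ys)) / (2.0 * n)


def hessian_trace(a, w, xs, ys, h=1e-3):
    """Central finite-difference estimate of d2L/da2 + d2L/dw2."""
    base = loss(a, w, xs, ys)
    d2a = (loss(a + h, w, xs, ys) - 2.0 * base + loss(a - h, w, xs, ys)) / h ** 2
    d2w = (loss(a, w + h, xs, ys) - 2.0 * base + loss(a, w - h, xs, ys)) / h ** 2
    return d2a + d2w


def rescale(a, w, c):
    """Symmetry of a positively homogeneous activation: the product a * w is unchanged."""
    return a / c, c * w


# Assumed toy data: positive inputs, noiseless labels y = 2 * x.
xs = [0.5, 1.0, 2.0]
ys = [2.0 * x for x in xs]

# (a, w) = (1, 2) interpolates the data because a * w = 2.
a0, w0 = 1.0, 2.0

# Same function, but sharper after rescaling with c = 10.
a_sharp, w_sharp = rescale(a0, w0, 10.0)

# Balanced representative |a| = |w|, reached with c = sqrt(|a| / |w|).
a_flat, w_flat = rescale(a0, w0, math.sqrt(abs(a0) / abs(w0)))

trace_orig = hessian_trace(a0, w0, xs, ys)
trace_sharp = hessian_trace(a_sharp, w_sharp, xs, ys)
trace_flat = hessian_trace(a_flat, w_flat, xs, ys)

print("losses:", loss(a0, w0, xs, ys), loss(a_sharp, w_sharp, xs, ys), loss(a_flat, w_flat, xs, ys))
print("traces:", trace_orig, trace_sharp, trace_flat)
```

At an interpolator of this toy model the trace equals (a^2 + w^2) times the mean of x^2, so with the product a * w fixed it is minimized when |a| = |w|. The finite-difference estimate is used only to keep the sketch independent of that closed form.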