Feed m Birds with One Scone: Accelerating Multi-task Gradient Balancing via Bi-level Optimization
2026-03-08 • Machine Learning
AI summary
The authors address the challenge of multi-task learning, where a model tries to learn several tasks at once without hurting any single task's performance. They note that previous methods like MGDA work well but are computationally inefficient because they require access to the gradients of all tasks. To fix this, the authors propose MARIGOLD, a framework that treats model training and gradient balancing as a bi-level optimization problem and solves it efficiently with a zeroth-order method. Their experiments on public and industrial-scale datasets indicate that the method is both efficient and effective.
multi-task learning, gradient descent, MGDA, gradient balancing, bi-level optimization, zeroth-order methods, machine learning optimization, task conflict, algorithm efficiency
Authors
Xuxing Chen, Yun He, Jiayi Xu, Minhui Huang, Xiaoyi Liu, Boyang Liu, Fei Tian, Xiaohan Wei, Rong Jin, Sem Park, Bo Long, Xue Feng
Abstract
In machine learning, the goal of multi-task learning (MTL) is to optimize multiple objectives jointly. Recent works, such as the Multiple Gradient Descent Algorithm (MGDA) and its variants, show promising results by dynamically adjusting the weights of different tasks to mitigate conflicts that may degrade performance on certain tasks. Despite their empirical success, a major limitation of MGDA-type methods is their computational inefficiency, as they require access to all task gradients. In this paper, we introduce MARIGOLD, a unified algorithmic framework for efficiently solving MTL problems. Our method reveals that multi-task gradient balancing methods have a hierarchical structure, in which model training and gradient balancing are coupled throughout the optimization process and can be viewed as a bi-level optimization problem. Moreover, we show that the bi-level problem can be solved efficiently by leveraging a zeroth-order method. Extensive experiments on both public and industrial-scale datasets demonstrate the efficiency and superiority of our method.
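To make the two ingredients named in the abstract concrete, here is a minimal illustrative sketch in Python. It is not the MARIGOLD algorithm, whose details the abstract does not give. It shows (a) the well-known closed-form MGDA weight for the two-task case, which needs both task gradients, and (b) a generic two-point zeroth-order estimate that uses only function values. The function names `mgda_two_task_weight` and `two_point_zo_estimate`, the step size `mu`, and the toy inputs are assumptions made for illustration.

```python
import numpy as np


def mgda_two_task_weight(g1, g2):
    """Return gamma in [0, 1] minimizing ||gamma * g1 + (1 - gamma) * g2||^2.

    This is the standard two-task min-norm subproblem behind MGDA-type
    methods; note that it requires both task gradients at once.
    """
    diff = g1 - g2
    denom = float(diff @ diff)
    if denom == 0.0:
        # Identical gradients: every weight yields the same direction.
        return 0.5
    gamma = float((g2 - g1) @ g2) / denom
    # Project the unconstrained minimizer back onto [0, 1].
    return min(1.0, max(0.0, gamma))


def two_point_zo_estimate(f, x, u, mu=1e-4):
    """Estimate the gradient of f at x along direction u from two function values.

    Generic central-difference zeroth-order estimator (hypothetical helper,
    not taken from the paper): no gradient of f is ever computed.
    """
    slope = (f(x + mu * u) - f(x - mu * u)) / (2.0 * mu)
    return slope * u


# Toy usage: two orthogonal unit gradients are balanced equally (gamma = 0.5).
g1 = np.array([1.0, 0.0])
g2 = np.array([0.0, 1.0])
gamma = mgda_two_task_weight(g1, g2)
balanced_direction = gamma * g1 + (1.0 - gamma) * g2

# Toy usage: estimate the partial derivative of a simple function along e_0.
toy_loss = lambda x: x[0] ** 2 + 3.0 * x[1]
zo_grad = two_point_zo_estimate(toy_loss, np.array([1.0, 2.0]), np.array([1.0, 0.0]))
```

The first function illustrates why MGDA-type balancing is costly: every update needs one gradient per task. The second illustrates the general idea of replacing gradient information with function evaluations; how MARIGOLD applies a zeroth-order method within its bi-level formulation is specified in the paper itself, not here.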