MemDLM: Memory-Enhanced DLM Training
2026-03-23 • Computation and Language
AI summary
The authors propose MemDLM, a method that improves Diffusion Language Models (DLMs) by making training more closely resemble how the model is used at prediction time. They add a memory mechanism through a two-level optimization process, in which a small set of fast-changing weights stores recent experience for each data point. This helps the model converge faster and reach lower training loss. The same memory can also be activated at prediction time, improving performance on long-context understanding and challenging retrieval tasks.
Diffusion Language Models, Auto-Regressive Models, Bi-level Optimization, Denoising, Parametric Memory, Fast Weights, Training-Inference Mismatch, Long-Context Understanding, Token-Level Attention, Retrieval Tasks
Authors
Zehua Pei, Hui-Ling Zhen, Weizhe Lin, Sinno Jialin Pan, Yunhe Wang, Mingxuan Yuan, Bei Yu
Abstract
Diffusion Language Models (DLMs) offer attractive advantages over Auto-Regressive (AR) models, such as full-attention parallel decoding and flexible generation. However, they suffer from a notable train-inference mismatch: DLMs are trained with a static, single-step masked prediction objective, but deployed through a multi-step progressive denoising trajectory. We propose MemDLM (Memory-Enhanced DLM), which narrows this gap by embedding a simulated denoising process into training via Bi-level Optimization. An inner loop updates a set of fast weights, forming a Parametric Memory that captures the local trajectory experience of each sample, while an outer loop updates the base model conditioned on this memory. By offloading memorization pressure from token representations to parameters, MemDLM yields faster convergence and lower training loss. Moreover, the inner loop can be re-enabled at inference time as an adaptation step, yielding additional gains on long-context understanding. We find that, when activated at inference time, this Parametric Memory acts as an emergent in-weight retrieval mechanism, helping MemDLM further reduce token-level attention bottlenecks on challenging Needle-in-a-Haystack retrieval tasks. Code: https://github.com/JarvisPei/MemDLM.
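The bi-level scheme described in the abstract — an inner loop that adapts per-sample fast weights (the Parametric Memory) while the base model is frozen, and an outer loop that updates the base model conditioned on that adapted memory — can be illustrated with a minimal toy sketch. This is not the authors' implementation; all names, the linear toy model, and the learning rates are hypothetical, and the real method operates on a simulated denoising trajectory rather than a simple regression loss.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "denoising" data: recover targets y from noisy inputs x.
x = rng.normal(size=(32, 4))
true_w = rng.normal(size=(4, 4))
y = x @ true_w

def loss(w, m, x, y):
    # Squared error of a linear predictor conditioned on fast weights m.
    return np.mean((x @ (w + m) - y) ** 2)

def grad(w, m, x, y):
    # Gradient of the loss w.r.t. (w + m); reused by both loops.
    return 2 * x.T @ (x @ (w + m) - y) / len(x)

def bilevel_step(w, x, y, inner_steps=3, inner_lr=0.05, outer_lr=0.05):
    # Inner loop: fast weights (per-sample parametric memory) adapt to
    # this batch while the base weights w stay frozen.
    m = np.zeros_like(w)
    for _ in range(inner_steps):
        m -= inner_lr * grad(w, m, x, y)
    # Outer loop: base weights updated conditioned on the adapted memory.
    w = w - outer_lr * grad(w, m, x, y)
    return w, m

w = np.zeros((4, 4))
before = loss(w, np.zeros_like(w), x, y)
for _ in range(50):
    w, m = bilevel_step(w, x, y)
after = loss(w, np.zeros_like(w), x, y)
```

Resetting `m` to zero at each step mirrors the idea that the memory captures only the local experience of the current sample; re-running the inner loop at inference time corresponds to the test-time adaptation step the abstract describes.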