Masked diffusion models (MDMs) are powerful generative models for discrete data that generate samples by progressively unmasking tokens in a sequence. Each token can take one of two states: masked or unmasked. We observe that token sequences often remain unchanged between consecutive sampling steps; consequently, the model repeatedly processes identical inputs, leading to redundant computation. To address this inefficiency, we propose the Partial masking scheme (Prime), which augments MDMs by allowing tokens to take intermediate states interpolated between the masked and unmasked states. This design enables the model to make predictions based on partially observed token information and facilitates a fine-grained denoising process. We derive a variational training objective and introduce a simple architectural design to accommodate intermediate-state inputs. Our method demonstrates superior performance across a diverse set of generative modeling tasks. On text data, it achieves a perplexity of 15.36 on OpenWebText, outperforming previous MDMs (21.52), autoregressive models (17.54), and their hybrid variants (17.58), without relying on an autoregressive formulation. On image data, it attains competitive FID scores of 3.26 on CIFAR-10 and 6.98 on ImageNet-32, comparable to leading continuous generative models.
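To make the intermediate-state idea concrete, the following is a minimal illustrative sketch, not the paper's actual Prime formulation: it assumes, purely for illustration, that an intermediate state is realized by linearly interpolating a token's embedding toward a learned mask embedding according to a per-token masking level. All names here (embed_table, mask_embedding, mask_level) are hypothetical and not taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size, dim, seq_len = 100, 16, 8

# Hypothetical embedding table and a dedicated mask embedding (assumed for this sketch).
embed_table = rng.normal(size=(vocab_size, dim))
mask_embedding = rng.normal(size=(dim,))

tokens = rng.integers(0, vocab_size, size=seq_len)

# Standard MDMs give each token one of two states: fully masked (level 1.0)
# or fully unmasked (level 0.0). A partial masking scheme instead allows
# intermediate levels in [0, 1], so the denoiser sees partially observed
# token information and can refine it gradually.
mask_level = rng.uniform(0.0, 1.0, size=(seq_len, 1))

token_emb = embed_table[tokens]                                        # (seq_len, dim)
model_input = (1.0 - mask_level) * token_emb + mask_level * mask_embedding

print(model_input.shape)  # (8, 16): fed to the denoising network
```

Under this toy assumption, a token with mask_level 0.0 is fully observed, 1.0 is fully hidden, and anything in between carries partial information, which is the property the abstract attributes to intermediate states.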