Dataset distillation seeks to condense datasets into smaller but highly representative synthetic samples. While diffusion models now lead all generative benchmarks, current distillation methods avoid them and rely instead on GANs or autoencoders, or, at best, sampling from a fixed diffusion prior. This trend arises because naive backpropagation through the long denoising chain leads to vanishing gradients, which prevents effective synthetic sample optimization. To address this limitation, we introduce Latent Dataset Distillation with Diffusion Models (LD3M), the first method to learn gradient-based distilled latents and class embeddings end-to-end through a pre-trained latent diffusion model. A linearly decaying skip connection, injected from the initial noisy state into every reverse step, preserves the gradient signal across dozens of timesteps without requiring diffusion weight fine-tuning. Across multiple ImageNet subsets at 128x128 and 256x256, LD3M improves downstream accuracy by up to 4.8 percentage points (1 IPC) and 4.2 points (10 IPC) over the prior state-of-the-art. The code for LD3M is provided at https://github.com/Brian-Moser/prune_and_distill.
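To illustrate why a linearly decaying skip connection from the initial noisy state can keep the gradient alive, the following minimal sketch tracks the derivative of the final latent with respect to the initial latent in a scalar toy chain. It is not the LD3M implementation: the denoiser is replaced by a hypothetical linear map with gain `step_gain`, and the blending rule `z_{t-1} = (t/T) * z_T + (1 - t/T) * f(z_t)` is an assumed form of the skip connection chosen only to match the description above. All function and parameter names are illustrative.

```python
def reverse_chain_gradient(num_steps, step_gain, use_skip):
    """Return d z_0 / d z_T for a scalar toy reverse chain.

    Assumptions (not taken from the paper):
      * each denoising step is the linear map f(z) = step_gain * z
      * the skip connection blends z_T back in with weight t / num_steps
    """
    grad = 1.0  # d z_T / d z_T
    for t in range(num_steps, 0, -1):
        # Gradient flowing through one toy denoising step.
        denoised_grad = step_gain * grad
        if use_skip:
            # Linearly decaying weight on the initial noisy state z_T.
            weight = t / num_steps
            # d z_{t-1} / d z_T = weight * 1 + (1 - weight) * d f(z_t) / d z_T
            grad = weight + (1.0 - weight) * denoised_grad
        else:
            grad = denoised_grad
    return grad


plain_grad = reverse_chain_gradient(50, 0.5, use_skip=False)
skip_grad = reverse_chain_gradient(50, 0.5, use_skip=True)
print(f"without skip: {plain_grad:.3e}")
print(f"with skip:    {skip_grad:.3e}")
```

In this toy setting with 50 steps and a contractive gain of 0.5, the plain chain yields a gradient of 0.5^50, which is below 1e-12, whereas the skip variant is bounded below by 1/50, since the last step alone contributes the weight 1/T directly from the initial state.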