Sharpness-aware minimization (SAM) has emerged as a highly effective technique for improving model generalization, but its underlying principles are not fully understood. We investigate the phenomenon known as m-sharpness, in which SAM's performance improves monotonically as the micro-batch size used to compute perturbations decreases. The empirical m-sharpness effect underpins the practical deployment of SAM in distributed training, yet a rigorous theoretical account has been lacking. To explain m-sharpness, we leverage an extended stochastic differential equation (SDE) framework and analyze the structure of stochastic gradient noise (SGN) to characterize the dynamics of several SAM variants, including n-SAM and m-SAM. Our findings reveal that the stochastic noise introduced during SAM perturbations inherently induces a variance-based sharpness regularization effect. Motivated by these theoretical insights, we introduce Reweighted SAM (RW-SAM), which employs sharpness-weighted sampling to mimic the generalization benefits of m-SAM while remaining parallelizable. Comprehensive experiments validate both our theoretical analysis and the proposed method.
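To make the difference between n-SAM and m-SAM concrete, the following is a minimal sketch assuming the standard m-SAM formulation: the batch is split into micro-batches of size m, each micro-batch computes its own normalized ascent perturbation of radius rho, the gradient is re-evaluated at its own perturbed point, and the per-micro-batch gradients are averaged. Setting m equal to the full batch size n recovers ordinary (n-)SAM. The toy least-squares loss and the helper names `grad` and `m_sam_step` are hypothetical illustrations, not the paper's implementation, and RW-SAM is not shown because its sampling rule is not specified here.

```python
import numpy as np


def grad(w, X, y):
    # Gradient of the toy loss 0.5 * mean((X @ w - y)**2) with respect to w (hypothetical example loss).
    return X.T @ (X @ w - y) / len(y)


def m_sam_step(w, X, y, lr=0.1, rho=0.05, m=4):
    """One m-SAM step (sketch): each micro-batch of size m gets its own perturbation."""
    n = len(y)
    assert n % m == 0, "batch size must be divisible by the micro-batch size m"
    sam_grads = []
    for start in range(0, n, m):
        Xb, yb = X[start:start + m], y[start:start + m]
        g = grad(w, Xb, yb)
        # Ascent direction: micro-batch gradient normalized to radius rho.
        eps = rho * g / (np.linalg.norm(g) + 1e-12)
        # SAM gradient: re-evaluate the same micro-batch at the perturbed weights.
        sam_grads.append(grad(w + eps, Xb, yb))
    # Average the per-micro-batch SAM gradients and take a descent step.
    return w - lr * np.mean(sam_grads, axis=0)
```

Choosing a smaller `m` means each perturbation is computed from fewer examples and is therefore noisier; in the abstract's account, this perturbation noise is what induces the variance-based sharpness regularization behind m-sharpness.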