Federated learning-assisted edge intelligence enables privacy protection in modern intelligent services. However, non-Independent and Identically Distributed (non-IID) data across edge clients can impair local model performance. The existing single-prototype strategy represents a class by the mean of its feature embeddings; however, the embeddings of a class are often spread across multiple clusters rather than a single one, so a single prototype may not represent the class well. Motivated by this, this paper proposes a multi-prototype federated contrastive learning approach (MP-FedCL) and demonstrates the effectiveness of a multi-prototype strategy over a single-prototype one under non-IID settings, including both label and feature skew. Specifically, a multi-prototype computation strategy based on \textit{k-means} is first proposed to capture the different embedding representations within each class space, using multiple prototypes ($k$ centroids) to represent a class in the embedding space. In each global round, the computed prototypes and their respective model parameters are sent to the edge server, aggregated into a global prototype pool, and then sent back to all clients to guide their local training. Finally, each client's local training minimizes its own supervised loss while learning from the shared prototypes in the global prototype pool through supervised contrastive learning, which encourages clients to acquire knowledge related to their own classes from others and reduces the absorption of unrelated knowledge in each global iteration. Experimental results on MNIST, Digit-5, Office-10, and DomainNet show that our method outperforms multiple baselines, with average test accuracy improvements of about 4.6\% and 10.4\% under feature and label non-IID distributions, respectively.
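To make the described pipeline concrete, the following is a minimal sketch (not the authors' released implementation) of the client-side multi-prototype computation and the server-side pooling: the function names, the default choice of $k$, and the use of scikit-learn's KMeans are illustrative assumptions.

```python
# Sketch of MP-FedCL's prototype steps under stated assumptions:
# per class, run k-means over that class's embeddings and keep the
# k centroids as prototypes; the server pools prototypes by class.
import numpy as np
from sklearn.cluster import KMeans

def compute_multi_prototypes(embeddings, labels, k=3):
    """Return {class_id: (k, d) array of prototypes} via per-class k-means.

    embeddings: (n, d) array of feature-space representations
    labels:     (n,) array of integer class labels
    k:          number of prototypes (centroids) per class (assumed value)
    """
    prototypes = {}
    for c in np.unique(labels):
        feats = embeddings[labels == c]
        # If a class has fewer samples than k, fall back to fewer centroids;
        # with one centroid this reduces to the single-prototype baseline.
        n_clusters = min(k, len(feats))
        km = KMeans(n_clusters=n_clusters, n_init=10).fit(feats)
        prototypes[int(c)] = km.cluster_centers_
    return prototypes

def aggregate_prototype_pool(client_prototypes):
    """Server side: merge per-client prototype dicts into a global pool
    keyed by class, which is then broadcast back to all clients."""
    pool = {}
    for protos in client_prototypes:          # one dict per client
        for c, centers in protos.items():
            pool.setdefault(c, []).append(centers)
    return {c: np.vstack(v) for c, v in pool.items()}
```

In local training, each client would then contrast its sample embeddings against the pooled prototypes of the same class (positives) versus other classes (negatives) in a supervised contrastive loss, alongside its ordinary supervised objective.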