让数百万台手机训练同一个模型?Google把这套框架开源了

2019 年 3 月 9 日 AI100


作者 | 琥珀

出品 | AI科技大本营(公众号id:rgznai100)


【导语】据了解,全球有 30 亿台智能手机和 70 亿台边缘设备。每天,这些电话与设备之间的交互不断产生新的数据。传统的数据分析和机器学习模式,都需要在处理数据之前集中收集数据至服务器,然后进行机器学习训练并得到模型参数,最终获得更好的产品。


但如果这些需要聚合的数据敏感且昂贵的话,那么这种中心化的数据收集手段可能就不太适用了。


去掉这一步骤,直接在生成数据的边缘设备上进行数据分析和机器学习训练呢?


近日,Google 开源了一款名为 TensorFlow Federated (TFF)的框架,可用于去中心化(decentralized)数据的机器学习及运算实验。它实现了一种称为联邦学习(Federated Learning,FL)的方法,将为开发者提供分布式机器学习,以便在没有数据离开设备的情况下,便可在多种设备上训练共享的 ML 模型。其中,通过加密方式提供多一层的隐私保护,并且设备上模型训练的权重与用于连续学习的中心模型共享。


传送门:https://www.tensorflow.org/federated/


实际上,早在 2017 年 4 月,Google AI 团队就推出了联邦学习的概念。这种被称为联邦学习的框架目前已应用在 Google 内部用于训练神经网络模型,例如智能手机中虚拟键盘的下一词预测和音乐识别搜索功能。



图注:每台手机都在本地训练模型(A);将用户更新信息聚合(B);然后形成改进的共享模型(C)。


DeepMind 研究员 Andrew Trask 随后发推称赞:“Google 已经开源了 Federated Learning……可在数以百万计的智能手机上共享模型训练!”



让我们一起来看看使用教程:


从一个著名的图像数据集 MNIST 开始。MNIST 的原始数据集为 NIST,其中包含 81 万张手写的数字,由 3600 个志愿者提供,目标是建立一个识别数字的 ML 模型。


传统手段是立即将 ML 算法应用于整个数据集。但实际上,如果数据提供者不愿意将原始数据上传到中央服务器,就无法将所有数据聚合在一起。


TFF 的优势就在于,可以先选择一个 ML 模型架构,然后输入数据进行训练,同时保持每个数据提供者的数据是独立且保存在本地。


下面显示的是通过调用 TFF 的 FL API,使用已由 GitHub 上的“Leaf”项目处理的 NIST 数据集版本来分隔每个数据提供者所写的数字:


GitHub 传送链接:https://github.com/TalwalkarLab/leaf



# Load simulation data.
source, _ = tff.simulation.datasets.emnist.load_data()
def client_data(n):
  dataset = source.create_tf_dataset_for_client(source.client_ids[n])
  return mnist.keras_dataset_from_emnist(dataset).repeat(10).batch(20)

# Wrap a Keras model for use with TFF.
def model_fn():
  return tff.learning.from_compiled_keras_model(
      mnist.create_simple_keras_model(), sample_batch)

# Simulate a few rounds of training with the selected client devices.
trainer = tff.learning.build_federated_averaging_process(model_fn)
state = trainer.initialize()
for _ in range(5):
  state, metrics = trainer.next(state, train_data)
  print (metrics.loss)


除了可调用 FL API 外,TFF 还带有一组较低级的原语(primitive),称之为 Federated Core (FC) API。这个 API 支持在去中心化的数据集上表达各种计算。


使用 FL 进行机器学习模型训练仅是第一步;其次,我们还需要对这些数据进行评估,这时就需要 FC API 了。


假设我们有一系列传感器可用于捕获温度读数,并希望无需上传数据便可计算除这些传感器上的平均温度。调用 FC 的 API,就可以表达一种新的数据类型,例如指出 tf.float32,该数据位于分布式的客户端上。


READINGS_TYPE = tff.FederatedType(tf.float32, tff.CLIENTS)


然后在该类型的数据上定义联邦平均数。


@tff.federated_computation(READINGS_TYPE)
def get_average_temperature(sensor_readings):
  return tff.federated_average(sensor_readings)


之后,TFF 就可以在去中心化的数据环境中运行。从开发者的角度来讲,FL 算法可以看做是一个普通的函数,它恰好具有驻留在不同位置(分别在各个客户端和协调服务中的)输入和输出。



例如,使用了 TFF 之后,联邦平均算法的一种变体:


参考链接:https://arxiv.org/abs/1602.05629


@tff.federated_computation(
  tff.FederatedType(DATASET_TYPE, tff.CLIENTS),
  tff.FederatedType(MODEL_TYPE, tff.SERVER, all_equal=True),
  tff.FederatedType(tf.float32, tff.SERVER, all_equal=True))
def federated_train(client_data, server_model, learning_rate):
  return tff.federated_average(
      tff.federated_map(local_train, [
          client_data,
          tff.federated_broadcast(server_model),
          tff.federated_broadcast(learning_rate)]))


目前已开放教程,可以先在模型上试验现有的 FL 算法,也可以为 TFF 库提供新的联邦数据集和模型,还可以添加新的 FL 算法实现,或者扩展现有 FL 算法的新功能。


据了解,在 FL 推出之前,Google 还推出了 TensorFlow Privacy,一个机器学习框架库,旨在让开发者更容易训练具有强大隐私保障的 AI 模型。目前二者可以集成,在差异性保护用户隐私的基础上,还能通过联邦学习(FL)技术快速训练模型。


最后附上 TF Dev Summit’19 上,TensorFlow Federated (TFF)的发布会现场视频:



参考链接:https://medium.com/tensorflow/introducing-tensorflow-federated-a4147aa20041


(本文为 AI科技大本营原创文章,转载请微信联系 1092722531


4 月13日-4 月14日,CSDN 将在北京主办“Python 开发者日( 2019 )”,汇聚十余位来自阿里巴巴IBM英伟达等国内外一线科技公司的Python技术专家,还有数百位来自各行业领域的Python开发者。目前购票通道已开启,早鸟票限量发售中,3 月15日之前可享受优惠价 299 元(售完即止)。


推荐阅读:

                         

点击“阅读原文”,查看历史精彩文章。

登录查看更多
3

相关内容

联邦学习(Federated Learning)是一种新兴的人工智能基础技术,在 2016 年由谷歌最先提出,原本用于解决安卓手机终端用户在本地更新模型的问题,其设计目标是在保障大数据交换时的信息安全、保护终端数据和个人数据隐私、保证合法合规的前提下,在多参与方或多计算结点之间开展高效率的机器学习。其中,联邦学习可使用的机器学习算法不局限于神经网络,还包括随机森林等重要算法。联邦学习有望成为下一代人工智能协同算法和协作网络的基础。
【Google】利用AUTOML实现加速感知神经网络设计
专知会员服务
29+阅读 · 2020年3月5日
【综述】7篇非常简洁近期深度学习综述论文
专知会员服务
74+阅读 · 2019年12月31日
谷歌机器学习速成课程中文版pdf
专知会员服务
145+阅读 · 2019年12月4日
年度大盘点:机器学习开源项目及框架
云栖社区
3+阅读 · 2018年12月17日
深度学习开发必备开源框架
九章算法
12+阅读 · 2018年5月30日
教程帖:深度学习模型的部署
论智
8+阅读 · 2018年1月20日
谷歌发布TensorFlowLite,用半监督跨平台快速训练ML模型!
全球人工智能
5+阅读 · 2017年11月15日
【机器学习】推荐13个机器学习框架
产业智能官
8+阅读 · 2017年9月10日
Mesh R-CNN
Arxiv
4+阅读 · 2019年6月6日
Arxiv
7+阅读 · 2018年5月23日
Arxiv
3+阅读 · 2018年3月21日
Arxiv
8+阅读 · 2018年1月25日
VIP会员
相关资讯
年度大盘点:机器学习开源项目及框架
云栖社区
3+阅读 · 2018年12月17日
深度学习开发必备开源框架
九章算法
12+阅读 · 2018年5月30日
教程帖:深度学习模型的部署
论智
8+阅读 · 2018年1月20日
谷歌发布TensorFlowLite,用半监督跨平台快速训练ML模型!
全球人工智能
5+阅读 · 2017年11月15日
【机器学习】推荐13个机器学习框架
产业智能官
8+阅读 · 2017年9月10日
相关论文
Mesh R-CNN
Arxiv
4+阅读 · 2019年6月6日
Arxiv
7+阅读 · 2018年5月23日
Arxiv
3+阅读 · 2018年3月21日
Arxiv
8+阅读 · 2018年1月25日
Top
微信扫码咨询专知VIP会员