首页 专利交易 科技果 科技人才 科技服务 商标交易 会员权益 IP管家助手 需求市场 关于龙图腾
 /  免费注册
到顶部 到底部
清空 搜索

一种基于知识蒸馏的联邦学习性能优化方法 

申请/专利权人:北京工业大学

申请日:2024-01-12

公开(公告)日:2024-04-30

公开(公告)号:CN117952189A

主分类号:G06N3/098

分类号:G06N3/098;G06N3/096;G06N3/084;G06N3/0985;G06F18/23213

优先权:

专利状态码:在审-实质审查的生效

法律状态:2024.05.17#实质审查的生效;2024.04.30#公开

摘要:本发明涉及一种基于知识蒸馏的联邦学习性能优化方法。在物联网中,大量的设备生成和处理数据,这些设备通常具有分布式和异构性的特点,联邦学习FederatedLearning,FL可以利用这些设备的计算能力进行并行化的模型训练,加快模型的训练速度。在物联网场景中,用于模型训练的数据一般由各设备自身提供,不同设备间的数据会存在异构性。异构的数据会降低全局模型的推理精度。为了解决这个问题,本专利提出一个新方案,本方案使用软标签传递客户端的知识,并且根据软标签的相似性对多个客户端分组,每轮联邦训练时从每组中选择一个最可靠的软标签参与聚合,这样来提高模型的准确度。

主权项:1.一种基于知识蒸馏的联邦学习性能优化方法,其特征在于包括五个阶段:本地训练,预测软标签,基于软标签分组,软标签聚合,拟合软标签;设定有K个参与方,每个参与方拥有自己的私有数据集Dk,k表示第k个参与方;Dk=x,y,x是数据集中的样本,y是样本的类别;每个参与方只能访问自己的数据集,除此之外还有一个全部参与方都能访问的公共数据集D;步骤1:本地训练K个参与方在公共数据集和私有数据集Dk上使用Adam算法训练自己的本地模型Fk;公式4描述了Adam算法训练过程中参数的更新:θt+1=θt-α*mtsqrtvt+ε4θt+1表示第t+1个迭代时刻的模型参数,θt表示第t个迭代时刻的模型参数,α表示学习率,ε取值为10-8;将学习率设置为0.01;mt表示当前时刻的一阶矩估计;vt表示当前时刻的二阶矩估计;步骤2:预测软标签每个参与方在数据集DP上预测软标签logitk,DP为在公共数据集D中生成的子集,各个参与方将各自的软标签logitk传递给服务器端; yi是样本i的预测类别,j是类别的索引,C是总类别数,而αijk表示样本i与样本k之间的关联度,其中k属于类别j;通过公式5得到的yi是硬标签,为了弥补硬标签的缺陷,同时为了减少客户端和服务器端之间的通信量,使用公式6计算的软标签来表示知识; 其中Pyi=j表示样本i属于类别j的概率,zij是样本i在类别j上的得分或相似度,τ是控制标签平滑程度的温度参数,C是总类别数;公式6中的指数函数将得分转化为概率值,温度参数τ控制了概率分布的平滑程度;τ设置为0.6;步骤3:软标签分组使用K-means算法将K个软标签分成M组;1初始化:随机选择K个软标签作为初始聚类中心;2分配步骤:公式7计算每个软标签与每个聚类中心的余弦相似度,并将软标签分配到最近的聚类中心所在的组; 3更新步骤:根据分配结果,更新每个组的聚类中心为该组内所有软标签的平均值;4重复步骤2和3直到达到收敛,即软标签的分配结果不再改变或达到最大迭代次数;最终将K个软标签分成M组;步骤4:软标签聚合服务器端计算每个客户端的软标签logitkm在数据集DP上的方差,logitkm表示的是第k个客户端的软标签,同时这个客户端属于第m个组; 其中,varlogitkm是属于第m个组,第k个用户在数据集DP上的方差,N是数据集中样本的总数,C是类别的数量,xij是第i个样本在第j个类别上的软标签概率,μj是第j个类别的软标签平均值;对于每一个组,把所有客户的方差升序排列,同时按排序之后的索引将软标签排序,也就是说这时候每个组中的软标签是按升序排列的,选择每个组中方差最低的软标签参与聚合,并且聚合时加大低方差软标签的权重; A[m]包含m个按降序排列的元素,且各元素之和为1;服务器端将logit_A广播给所有参与聚合的客户端;步骤5:拟合软标签各个客户端在数据集DP继续训练并且使训练结果拟合logit_A,各个客户端依然需要在私有数据集Dk上训练,此操作记作traink; 其中,k表示客户端的索引,Fk是客户端k的模型参数,N是客户端k拥有的训练样本数目,yi是第i个训练样本在客户端k上的预测标签,该公式表示每个客户端通过最小化损失函数来调整自身的模型参数,以使其在本地数据集上的训练结果更好地拟合logit_A;具体来说,对于每个客户端,公式的目标是找到最优的模型参数使其在本地数据集上的预测结果yi与聚合软标签logit_A之间的损失最小化;不断循环步骤2到步骤5,直到达到满意的训练精度;设置理想的结果是该方法在MNIST数据集上的推理精度在90%以上,在CIFAR-10数据集上的推理精度在60%以上,在CIFAR-100数据集上推理精度在40%以上。

全文数据:

权利要求:

百度查询: 北京工业大学 一种基于知识蒸馏的联邦学习性能优化方法

免责声明
1、本报告根据公开、合法渠道获得相关数据和信息,力求客观、公正,但并不保证数据的最终完整性和准确性。
2、报告中的分析和结论仅反映本公司于发布本报告当日的职业理解,仅供参考使用,不能作为本公司承担任何法律责任的依据或者凭证。