← 返回列表

面向非独立同分布数据的联邦知识蒸馏方法及装置

申请号: CN202311714820.2
申请人: 合肥高维数据技术有限公司
更新日期: 2026-03-09

专利详细信息

项目 内容
专利名称 面向非独立同分布数据的联邦知识蒸馏方法及装置
专利类型 发明授权
申请号 CN202311714820.2
申请日 2023/12/14
公告号 CN117408330B
公开日 2024/3/15
IPC主分类号 G06N3/096
权利人 合肥高维数据技术有限公司
发明人 田辉; 王欢; 郭玉刚; 张志翔
地址 安徽省合肥市高新区望江西路900号中安创谷科技园一期A1栋21楼

摘要文本

本申请涉及一种面向非独立同分布数据的联邦知识蒸馏方法及装置,其包括根据公共数据集进行随机采样,获取辅助数据集;基于预设的优化函数以及辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;将生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入生成网络模型,得到生成网络数据;控制客户端基于预设的数据融合算法、生成网络数据以及预设的本地数据进行数据融合,获取融合数据;控制客户端根据预设的局部模型蒸馏算法以及融合数据对深度学习模型进行优化训练,得到全局模型,本申请通过生成网络模型和局部模型蒸馏算法对客户端的深度学习模型进行优化,减少深度学习模型的优化目标与全局优化目标的偏差。 来自:

专利主权项内容

1.一种面向非独立同分布数据的联邦知识蒸馏方法,其特征在于,所述方法包括:根据预设的公共数据集进行随机采样,获取辅助数据集;基于预设的优化函数以及所述辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;将所述生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入所述生成网络模型,得到生成网络数据;控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取融合数据;控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型;其中,所述优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数;所述对抗目标损失函数的计算公式为:
;其中,为所述辅助数据集中的数据样本,/>为所述噪声向量,/>为所述生成网络,和/>分别代表所述生成网络/>和所述鉴别网络/>的模型参数;所述互信息平滑损失函数的计算公式为:
;其中,代表一次批处理过程中所述噪声向量/>的数量;所述相似度惩罚损失函数的计算公式为:
;其中,和/>代表重复采样过程中不同的噪声向量;其中,所述控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取所述融合数据,包括:基于所述生成网络模型生成的所述生成网络数据/>和客户端的所述本地数据/>通过所述数据融合算法进行融合,得到所述融合数据/>;其中,所述数据融合算法的计算公式为:


;其中,为基于随迭代次数从最小值0增加到最大值0.5的动量参数,/>为样本的伪标签,/>和/>为合成后的数据样本和标签;其中,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型,包括:计算所述生成网络数据与所述本地数据之间的数量比例;控制客户端基于所述局部模型蒸馏算法、所述数量比例以及所述融合数据对所述深度学习模型进行优化训练,得到所述全局模型;其中,所述局部模型蒸馏算法的计算公式为:
;其中,其中为所述本地数据的样本数量,/>为所述生成网络数据的样本数量,/>是代表客户端本地的深度学习模型/>在所述生成网络数据/>和所述融合数据/>之间Kullback-Leibler距离,/>为用于调整知识蒸馏强度的参数,/>为所述生成网络数据中标签为/>的样本数量,/>则代表归一化指数函数;其中,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型之后,还包括:接收全体客户端深度学习模型的模型参数;基于每个客户端的所述模型参数通过可学习参数进行加权处理,得到集成模型;基于所述生成网络模型批量生成的所述生成网络数据,得到虚拟数据集;基于全局聚合蒸馏算法和集成模型,通过解耦所述生成网络数据中的类别信息对全局模型进行微调,得到全局微调模型;将所述全局微调模型重新分发给各个客户端,控制每个客户端根据所述局部模型蒸馏算法以及所述融合数据、所述全局聚合蒸馏算法和所述集成模型对所述全局微调模型进行优化训练,直至所述全局微调模型收敛或者达到指定精度;其中,所述集成模型的计算公式为:
&其中,是一个可学习参数并处于0到1之间,/>则是用于控制权重参数正则化的程度,/>代表客户端上的所述模型参数;所述全局聚合蒸馏算法的定义如下:
;其中代表所述全局模型,/>代表所述集成模型,/>为所述虚拟数据集中的数据样本。