← 返回列表

一种基于模型内部特征蒸馏的文本分类方法

申请号: CN202410064744.3
申请人: 长春大学
申请日期: 2024/1/17

摘要文本

本发明提出一种基于模型内部特征蒸馏的文本分类方法,属于自然语言处理领域;包括:首先对微博文本数据集进行预处理,使用tokenizer将文本转换为模型所需的特征;其次将特征分别传入学生模型和教师模型提取预测结果;然后将教师模型中内部的特征蒸馏出来作为软标签,与学生模型中内部的特征计算损失loss1,同时利用学生模型和教师模型的预测结果计算损失loss3,利用学生模型预测结果与学生模型的标签计算损失loss2;最后将得到的三个损失求和,在网路中进行反向传播,优化网络。本发明在文本分类任务中,压缩了模型的内存,提升了模型的性能,更好的权衡了模型大小和模型性能。 百度搜索马 克 数 据 网

专利详细信息

项目 内容
专利名称 一种基于模型内部特征蒸馏的文本分类方法
专利类型 发明申请
申请号 CN202410064744.3
申请日 2024/1/17
公告号 CN117807235A
公开日 2024/4/2
IPC主分类号 G06F16/35
权利人 长春大学
发明人 王绍强; 靳晓娇; 申向峰; 戴银飞; 王艳柏; 刘玉宝
地址 吉林省长春市朝阳区卫星路6543号

专利主权项内容

1.一种基于模型内部特征蒸馏的文本分类方法,其特征在于,包括以下步骤:步骤一、获取微博文本数据集,将微博文本数据集进行划分和预处理,其中,将获取的微博文本数据集划分出训练集和测试集,并使用tokenizer对训练集和测试集进行数据预处理,将文本转换为模型输入的包含特征向量的可识别字典;步骤二、基于预训练模型,构建教师模型和学生模型;构建教师模型,教师模型基于BERT预训练模型,主干网络利用GRU模型结合双向Bi-LSTM模型和Transformer_Attention模型搭建成带有三条路径的教师模型,三条路径产生的特征进行拼接和全连接后,得到最终的预测结果;搭建学生模型,学生模型基于RoBERTa预训练模型,主干网络仅使用GRU模型搭建成带有一条路径的学生模型;步骤三、利用步骤二中的教师模型和学生模型进行知识蒸馏;在训练过程中选择相应的损失函数计算模型损失,并保存最优模型;其中,相应的损失函数为torch.nn.functional模型库中的交叉熵CE函数、KL散度函数和均方误差MSE函数;将教师模型中内部的特征蒸馏出来作为软标签,与学生模型中内部的特征计算损失,得到损失值loss1;利用学生模型预测结果与学生模型的标签计算损失,得到损失值loss2;利用学生模型和教师模型的预测结果计算损失,得到损失值loss3;最后将三个损失相加得到模型的最终损失值Loss;通过反向传播优化参数,训练该模型,通过比对损失值的大小决定是否保存训练得到的模型及参数,训练过程中只保存损失值最小的模型结构和模型参数。