Skip to content

Latest commit

 

History

History
57 lines (30 loc) · 3.21 KB

半监督入门思想之伪标签.md

File metadata and controls

57 lines (30 loc) · 3.21 KB

伪标签,是啥?

今天分享的论文是 [Pseudo-Label]("The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks")

从这个论文,主要是解决三个知识点:

  1. 什么是伪标签
  2. 怎么使用伪标签
  3. 伪标签为啥有用

伪标签

先说第一个问题,假设我们现在有一个文本分类模型(先不用管分类模型是怎么来的以及怎么训练的),以及大量的无标注数据。

我们现在使用文本分类模型对无标注数据进行预测,挑选softmax之后概率最大的那个类别为当前无标注数据对应的标签。

因为是无标注数据而且我们模型准确不可能是百分之百,从而导致预测的这个标签我们并不清楚是不是精准,所以我们称之为"伪标签"。

怎么使用伪标签

“伪标签”可以帮助模型学习到无标注数据中隐藏的信息。

我们先来看模型的损失函数是如何定义的:

损失函数

公式的前半部分针对的是标签数据的损失。我们重点来看后半部分伪标签的损失函数。

$C$ 是类别数目,$n^{,}$ 是batch数据中伪标签(无标注)数据的数量大小。$y^{,m}$ 是无标注数据的伪标签,$f^{,m}$是无标注数据的输出。$\alpha(t)$是未标注数据的权重,更新如下:

伪标签权重更新

这个更新公式值得看看,从这里可以看到,在$T_{1}$ steps之前,只是在训练数据上进行训练。随着模型的训练,无标注数据的损失函数权重在慢慢的增加。

简单来说,就是模型现在标注数据上进行训练,到一定steps之后,开始使用无标签数据的损失函数。

伪标签为啥有用

其实,从上面这个损失函数,最好奇的一点就是为什么我加了后半部分的无标签数据的损失之后(也就是在训练的时候使用无标签数据的伪标签计算损失之后),模型的表现会比只是使用标签数据要好。

损失函数的第二项,利用了熵最小化的思想。

从形式上来看,它的这个损失是在强迫模型在无标签数据上的输出更加的集中,逼近其中的一个类别,从而使得伪标签数据的熵最小。

在这个过程中,什么时候加入对伪标签的考量就很重要,因为如果太早的话,模型在训练数据上训练的并不是很好,那么模型在预测数据上的输出置信度其实就很低,误差会慢慢积累变大。

所以$\alpha(t)$是一个很重要的部分。

总结

伪标签在我的理解中,就是在模型已经训练的还可以的时候,对无标签数据进行预测,我们通过损失函数,让无标签数据逼近其中某一类(其实本质也是在做GT的文本分类)

想一下,Bert在小样本上进行finetuning之后,我们也可以把它放在无标签数据上直接预测。

由于Bert强大的能力,这样预测出来的标签置信度是很高的,我们一般可以直接拿这个结果作为冷启动的一部分。

伪标签数据半监督入门的思想,之后有时间会慢慢深入的分享几个论文。