MI分类CNN+Transformer

论文名称 A Transformer-Based Approach Combining Deep Learning Network and Spatial-T emporal Information for Raw EEG Classification
期刊 IEEE TRANSACTIONS ON NEURAL SYSTEMS AND REHABILIT A TION ENGINEERING 4.528/Q1
方法 我们提出了一种新的基于transformer模块的脑电深度学习分类结构,并分析了MI-EEG分类的模型行为。主要贡献如下:首先,我们设计了一种新的transformer模型,用于研究类脑神经机制。五类基于transformer的模型,包括空间transformer(s-Trans)、时间transformer(t-Trans),空间CNN+transformer(s-CTrans)、时间CNN+transformer(t-CTrans)和融合CNN+transformer(f-CTrans),并在Physionet EEG运动/图像数据集上系统地测试了这些模型。
结论 利用3s数据,我们的模型在两类、三类和四类分类任务中分别获得了83.31%、74.44%和64.22%的最佳精度,优于其他最先进的(SOTA)模型。接下来,我们探索了三类Pos_Embeding(PE)模块、相对位置编码、信道相关位置编码和学习位置编码。与没有位置信息的基线模型相比,嵌入位置编码的准确率提高了0.36%-2.63%,证明了Pos_Embeding方法可以提高脑电分类能力。最后,可视化了s-Trans模型中多头注意模块的权重。根据EEG电极位置绘制可视化图。我们发现,对应于感觉运动区域的transformer模块的权重显示了事件相关去同步(ERD)模式。这些数据与Mu和β带节律ERD一致。可视化结果表明,基于transformer的方法有助于理解基于脑电数据的分类任务的网络行为。

数据处理

在本研究中,我们重点研究了运动想象分类,并选择了以下内容:

  • 睁眼休息状态的基线运行(O)
  • 针对左拳(L)对右拳(R)的运动想象进行三次任务
  • 双拳对双脚运动想象的三次任务(F) 选择21项试验。每次试验持续8秒,前2秒用于休息,后4秒用于运动想象,最后2秒用于静止。使用3段(480个样本)和6段(960个样本)脑电图数据训练和测试我们的模型。我们使用3s和6s数据进行分类。3s数据包括来自运动成像周期的第一个3s数据,6s数据包括整个运动成像周期以及运动成像周期前一秒和后一秒。我们应用==Z-score归一化==来预处理EEG数据,并添加随机噪声以防止过度拟合,如以下公式所示: \(\mathrm{X}^{*}=\frac{\mathrm{X}-\mu}{\delta}+\alpha \mathrm{N}\)
符号 含义
$X$ 表示原始EEG数据
$X^{*}$ 表示预处理后的脑电数据
μ 表示数据的平均值
δ 表示标准偏差
N 表示随机噪声
α 控制随机噪声的百分比

模型结构

Img 该框架演示了原始EEG数据的端到端分类。该网络由transformer模块以及Pos_Embeding操作组成。 我们还设计了结合CNN模块和Transformer模块的方法。包括CNN是因为其良好的特征表示特性。

在实现中,我们总共构建了五个基于transformer的模型,其中两个模型仅依赖transformer,不包括CNN,

三个模型使用CNN和transformer组合的网络架构。在CNN和transformer模块之后,我们加入了一个完全连接的层。

Transformer

采用了Transformer的网络架构,该架构在自然语言处理(NLP)的翻译质量方面取得了优异的性能。与大多数竞争性神经序列转导模型一样,Transformer模块遵循encoder-decoder结构,使用stacked self-attention和point-wise全连接层。该模型将输入向量与三个不同的权重矩阵相乘,以获得查询向量(Q)、键向量(K)和值向量(V)。 Img

"缩放点积注意力"如图所示,该图计算了具有所有键的查询的点积,每个键除以√dk,并应用softmax函数以获得值的权重,如公式所示:

\(Q=W_{Q} X\) \(K=W_{K} X\) \(V=W_{V} X\)

self-attention使用==点积==模型,输出向量可以写成公式。 \(H=V \operatorname{softmax}\left(\frac{K^{\mathrm{T}} Q}{\sqrt{d_{k}}}\right)\) 之后$d_k$按比例缩放时,通过softmax函数获得每个输出通道上的注意力分数。 Img

"多头注意力"如图和公式所示: Img 在本研究中,我们采用了h=8个平行注意层(所谓的8个注意头),并将Transformer的编码器部分单独嵌入到EEG分类中。如图所示 Img

Transformer模块有两个子模块。第一个子模块包括一个多头注意力层,然后是一个规范化层。第二个子模块包括位置全连接前馈层,然后是归一化层。围绕两个子模块中的每一个子模块采用剩余连接。

Positional Embedding

在自然语言处理领域,前人在Transformer模型中嵌入了Positional Embedding,以增加位置信息。对于EEG数据,我们探索了三类位置嵌入(PE)模块:相对位置编码、信道相关位置编码和学习位置编码。相对位置嵌入方法使用正弦和余弦函数来表示相对位置编码 \(\begin{aligned} \mathrm{PE}_{(\text {pos }, 2 \mathrm{i})} &=\sin \left(\frac{\operatorname{pos}}{10000^{2 \mathrm{i} / \mathrm{d}}}\right) \\ P E_{(\operatorname{pos}, 2 \mathrm{i}+1)} &=\cos \left(\frac{\operatorname{pos}}{10000^{2 \mathrm{i} / \mathrm{d}}}\right) \end{aligned}\)

符号 含义
pos 表示通道位置
i 表示时间点
d 表示向量的维数

原理

在电极矢量的每个位置,偶数和奇数时间点的PE分别由正弦函数和余弦函数描述,i是电极矢量中节点的索引除以2。我们对位置pos1和pos2进行了相对位置编码的内积,发现随着距离的增加,两个位置之间的相关性变小。

应用

在通道相关Positional Embedding中,选择Cz电极作为中心电极,并计算其他电极与中心电极之间的余弦距离。Pcentral表示中心电极Cz的三维坐标,Pk是三维坐标的第k个位置。通过使用simk距离而不是公式中的pos来执行运算,所得矩阵是信道相关位置编码的位置编码矩阵。 \(\operatorname{sim}_{\mathrm{k}}=\frac{\mathrm{P}_{\text {central }} \cdot \mathrm{P}_{\mathrm{k}}}{\left\|\mathrm{P}_{\text {central }}\right\|\left\|\mathrm{P}_{\mathrm{k}}\right\|}\) 在学习的位置嵌入方法中,我们将相同大小的可训练矩阵嵌入到输入中,并随机初始化嵌入矩阵。在模型训练过程中,通过学习不断更新位置编码矩阵的参数。

时空Transformer

为了考虑EEG数据中时间和空间维度的相关性,我们以空间和时间两种方式排列Transformer模块的输入数据。在空间方向上(s-Trans,图3a),来自每个通道的沿时间轴的EEG数据被视为特征,Transformer模块计算不同通道之间的相关性。在时间方向上(t-Trans,图3b),将同一时间点沿通道轴的EEG数据视为特征,模型计算不同时间点之间的相关性。 Img 空间或时间方向的EEG数据嵌入位置编码,然后馈送到Transformer模块中。将Transformer模块获得的特征输入到全连接层中进行脑电分类。我们探讨了Transformer模块数量对分类结果的影响。Transformer模块的数量从1个测试到6个。当选择3个时,分类效果最佳。因此,我们在模型中包括三个Transformer模块。

CNN + Transformer

神经网络(CNN)已用于EEG数据的广义特征学习和降维,我们设计了一个融合模型,将CNN和Transformer模块结合起来。CNN实现还考虑了空间和时间表示。CNN执行特征提取,并将这些特征输入Transformer的多头注意层。 Img

在CNN+Transformer模型的空间实现中(s-CTrans,图3c),CNN模块包括两个卷积层和一个平均池层。在第一卷积层,我们使用64个1×16(通道×时间点)大小的核来提取EEG时间信息,并采用相同的填充。平均池层的池大小为1×32。第二个卷积层使用64个核,大小为1 x 15,并采用VALID填充。 Img

在CNN+Transformer模型(t-Ctrans,图3d)的时间实现中,CNN模块包括一个卷积层和一个平均池层。卷积层使用64个大小为64×1(通道×时间点)的核来提取EEG空间信息,并采用相同的填充。平均池层的池大小为1×8。在平均池层之后,我们转置了特征。 对于s-CTrans和t-Ctraans模型,CNN模块获得的特征嵌入位置编码,然后通过Transformer模块,然后通过全连接层进行EEG分类,如图3c-d所示。 Img

融合CNN+Transformer模型(f-CTrans)并行处理空间和时间信息(图3e)。 在CNN和Transformer处理之后,来自两个流的两个输出被合并。将组合特征输入到全连通层中进行脑电分类。

训练

| 名称 | 数值/方法 | | – | – | |multi-head attention层数|8| |优化器|adam optimizer+early stopping| |具有ReLU激活的position-wise全连接前馈层的参数|512| |dropout|0.3| |损失函数|交叉熵函数| |权重衰减|0.0001| |training epoch|50| |PE中的d维数|480(3s)或960(6s)| |模型验证|5倍交叉验证| Img

结果

分类精度

Img 使用3s数据,我们的模型对两类、三类和四类分类的最佳精度分别为83.31%、74.44%和64.22%。我们的结果在所有三种分类中都优于基线模型。 使用6s数据,我们模型的最佳精度分别为87.80%、78.98%和68.54%。因此,包含较长周期的EEG数据可以产生更高的分类精度。与基线模型相比,我们的方法适用于三类和四类分类。在三类和四类分类的情况下,对于3s数据,我们的f-CTrans表现最好,对于6s数据,t-CTrans模型表现最好。

因此,我们基于Transformer的分类方法具有很强的分类能力。

PE的效果

为了了解位置嵌入(PE)对分类的贡献,我们在s-Trans模型中比较了使用三种PE方法的分类结果。与没有PE的模型相比,三种PE方法具有更好的分类结果(表III)。对于3s和6s数据,不同的PE方法具有不同的精度,尽管差异不大。应该注意的是,学习的位置嵌入方法需要训练并且具有更多的训练参数。因此,在我们的模型中加入PE方法提高了分类精度。 Img

补充

| | Point wise | pairwise |list wise| | – | – | – |–| | 思想 | Pointwise排序是将训练集中的每个item看作一个样本获取rank函数,主要解决方法是把分类问题转换为单个item的分类或回归问题。 | Pairwise排序是将同一个查询中两个不同的item作为一个样本,主要思想是把rank问题转换为二值分类问题 |Listwise排序是将整个item序列看作一个样本,通过直接优化信息检索的评价方法和定义损失函数两种方法实现。| | 算法 |1、基于回归的算法;2、基于分类的算法;3、基于有序回归的算法| 基于二分类的算法 |直接基于评价指标的算法非直接基于评价指标的算法|

  • ranking 追求的是排序结果,并不要求精确打分,只要有相对打分即可。pointwise 类方法并没有考虑同一个 query 对应的 docs 间的内部依赖性。一方面,导致输入空间内的样本不是 IID 的,违反了 ML 的基本假设,另一方面,没有充分利用这种样本间的结构性。其次,当不同 query 对应不同数量的 docs 时,整体 loss 将会被对应 docs 数量大的 query 组所支配,前面说过应该每组 query 都是等价的。损失函数也没有 model 到预测排序中的位置信息。因此,损失函数可能无意的过多强调那些不重要的 docs,即那些排序在后面对用户体验影响小的 doc。
  • 1、如果人工标注包含多有序类别,那么转化成 pairwise preference 时必定会损失掉一些更细粒度的相关度标注信息。2、doc pair 的数量将是 doc 数量的二次,从而 pointwise 类方法就存在的 query 间 doc 数量的不平衡性将在 pairwise 类方法中进一步放大。3、pairwise 类方法相对 pointwise 类方法对噪声标注更敏感,即一个错误标注会引起多个 doc pair 标注错误。4、pairwise 类方法仅考虑了 doc pair 的相对位置,损失函数还是没有 model 到预测排序中的位置信息。5、pairwise 类方法也没有考虑同一个 query 对应的 doc pair 间的内部依赖性,即输入空间内的样本并不是 IID 的,违反了 ML 的基本假设,并且也没有充分利用这种样本间的结构性。
  • listwise 类相较 pointwise、pairwise 对 ranking 的 model 更自然,解决了 ranking 应该基于 query 和 position 问题。listwise 类存在的主要缺陷是:一些 ranking 算法需要基于排列来计算 loss,从而使得训练复杂度较高,如 ListNet和 BoltzRank。此外,位置信息并没有在 loss 中得到充分利用,可以考虑在 ListNet 和 ListMLE 的 loss 中引入位置折扣因子。

https://zhuanlan.zhihu.com/p/337478373