从化学结构式识别竞赛看小模型的前景
化学结构式识别竞赛(Competition on Recognition of Chemical Structures,CROCS)作为2024年文档分析与识别国际会议(International Conference on Document Analysis and Recognition,ICDAR)的一部分,其竞赛报告已经公开。作为本届的其中一个参赛者,我开发的超轻量系统仅以0.61个百分点之差不敌知网而领先季军高达8个百分点。值得注意的是,知网集成了多个参数量高达约350M的模型,而我只用了4个参数量约只有1.3M的模型就取得了相当有竞争力的成绩。可见,虽然采用大模型是一种趋势,但小模型的潜力并未充分发掘。
竞赛结果
竞赛结果已经在竞赛平台和竞赛报告公布。各参赛队伍的成绩如下:
队伍 | 准确率 | 结构准确率 |
知网 | 70.66% | 80.12% |
我们 | 70.05% | 79.44% |
内蒙古大学 | 62.03% | 72.72% |
华中科技大学 | 56.76% | 67.56% |
基辅国立大学 | 50.11% | 67.16% |
基线DenseWAP | 29.98% | 38.14% |
可见,我们的成绩已经非常接近知网的冠军系统,而大幅抛离季军的系统。
大模型在化学结构式方面并未展示出明显优势
本次比赛中,冠军中国知网使用了基于transformer的Nougat大模型,单个模型参数高达350M(3.5亿)。他们还使用了基于投票的模型集成策略,虽然没有说明模型数量,但有理由相信他们至少集成了三个模型,总参数量至少1B(10.5亿)。相反,我们只集成了4个1.3M的小模型,总参数量不过约5.2M,还不到知网的200分之1,但仍然取得了相当有竞争力的成绩。只要再集成多2个模型,甚至有希望超越知网的那次提交,而参数量仍然比它小很多(不过,我怀疑他们也收藏了更多的模型没有用上或隐藏了最优成绩,和大公司拼算力没什么前途)。这说明,基于CNN和RNN的小模型目前仍然有能力在这领域与基于transformer的大模型在准确性上竞争,而小模型具备固有的低推理成本优势,在计算量、内存占用和模型大小上都更轻量。
按照目前人们广泛信奉的scaling law,通过扩大模型规模和数据量可以稳定地提高的模型的能力,即使边际效益一直在递减。在这个竞赛中,数据量可能成为了令大模型未能辗压小模型的瓶颈,约5万张真实手写图片对于3.5亿参数的模型来说也许不太足够,合成的逼真手写体数据又欠奉,这使得大模型的强项未能充分发挥,反而容易过拟合。通过网络架构和损失函数的设计可以向专用小模型注入先验知识,但通用的大模型则只能从有限的样本中自己学。比如说,CNN大致保证了图片中一个对象在平移后仍然能认出来,覆盖向量能让模型记住图像中哪个区域已经转换从而减低遗漏/重复识别的情况。由此可见,对于收集和标注真实数据成本高昂的任务(例如这次竞赛中带位置标注的化学结构式),小模型仍然是有发展潜力的。
小模型的优化技巧
基于对数据/模型的理解进行针对性实验,往往可以提高小模型的参数效率。虽然我们与季军队伍都基于在光学数学公式识别领域颇为影响力的DenseWAP模型,但我们的系统不但在准确性上显著优于他们的,而且总参数量甚至比原版的DenseWAP更低一些。事实上,DenseWAP的论文出版于2018年,此后不管是数学公式识别还是更一般的计算机视觉都有了长足的进步。以下简单介绍如何把部分这些成果整合到DenseWAP中,更详细的实验结果见技术报告。
数据处理
调整标签表示和分词方式可以降低序列生成模型的学习难度。化学结构式可以用不同的方式表示为字符串,但不同的格式对图像到序列模型来说的学习难度并不相同,承办方自己的论文就已经注意到这一点。他们已经对数据进行了一些规范化,以减低遍历顺序造成表示的不惟一性,但我们还做了更多。首先,我们对键角和键长进行离散化以把回归问题化为分类问题,具体地,我们把键角舍入到15度的倍数并把360度与0度等同,而忽略键长。其次,我们用两个token来分别表示键角和键的类型(基线系统把两者合起来当作一个token)以减低类别分布的不均匀性和增加每个类别的出现次数,这也使得预测在训练中从未见过的键角类型组合变得可行。即使对于只需要输出SMILES代码而不需要位置信息的场合,用化学信息工具包如CDK或RDkit对SMILES代码进行规范化也是有益的。
数据增强现已是提高模型泛化能力是标准方法。在训练过程中,我们对图片进行随机的仿射变换、模糊、锐化、亮度调整、对比度调整和噪声增加等操作,从而提高模型对不同拍照条件的鲁棒性。知网也用了类似的策略,而内蒙古大学则只用了缩放。粗略来说,卷积神经网络会学到平移不变的特征,但往往学不到缩放不变的特征,故随机缩放这种数据增强策略对基于图像的数学公式/化学结构式识别特别有用。
合成数据在我们在CROHME 2023中致胜的关键。不过,由于官方使用的数据格式是他们自定义的,缺乏现成的高质量解析和渲染工具,而我不能在这个竞赛上投入太多时间去开发,我们这次并未使用合成数据。值得一提的是,知网合成了一批印刷体图片。由于他们使用了参数量更大的模型,应该需要更多训练数据,而官方的训练集约只有5万张手写图片,故他们可能更迫切地需要合成数据。在未来,合成逼真的手写体图片应该会是化学结构式识别的一个方向。
骨干网络的现代化
基于编码器-解码器架构的DenseWAP使用DenseNet作为编码器。众所周知,骨干网络对计算机视觉模型的效果影响较大,而DenseNet在ImageNet等评测上早已经被ResNet的各种变种超过,故这个编码器可能有优化空间。不过,既然DenseNet在图像转序列任务上是行之有效的,擅长利用不同抽象层次的特征,我们决定参考其它骨干网络的设计去对DenseNet进行现代化,而不是冒险把它整个换掉。以下是我们对它的主要修改:
- 使用深度可分离卷积。自MobileNet系列开始,人们就开始广泛利用深度可分离卷积来大幅减低参数量和计算量,EfficientNet和ConvNeXt等基于CNN的主干网络都是例子。特别地,我们把DenseNet中除首个卷积层外的所有卷积层从全卷积换成了逐点卷积和深度可分离卷积的组合,就像MobileNet v2中那样。深度可分离卷积同时为增大卷积核创造了条件。
- 使用更大的卷积核。原来的DenseNet除首层外其它卷积核的窗口大小都是3x3,但人们发现使用更大的卷积核往往可以通过扩大感受野来提高特征提取能力,MobileNet v3、EfficientNet和ConvNeXt等都使用了5x5甚至更大的卷积核。因此,我们使用5x5的卷积核。
- 使用实例正规化(Instance Normalization)。正规化是使训练深层神经网络变得可行的关键技术,许多骨干网络就使用了批正规化(Batch Normalization)。批正则化的一个好处是在推理时可以折叠到前面的卷积/全连接层中,因而在推理阶段往往是没有开销的。不过,批正则化在训练阶段和推理阶段的计算不一致,这在批大小较小时对泛化能力不利。ConvNeXt等就把批正规化换成了层正规化(Layer Normalization)。由于我们基于DenseNet,网络中许多卷积层的输出通道数低,层正规化会丢失的信息还是有点多,所以我们把批正规化换成了实例正规化。
- 使用较小的通道数。由于一般认为神经网络表达能力更多来源于深度而非宽度,我们通过进一步缩小网络宽度来压缩参数量和计算量。
损失函数
作为一个自回归的模型,交叉熵固然可以监督每步解码所作的分类。但是,我们还可以做更多。
- 计数损失函数。近年在数学公式识别中有人采用了弱监督策略以引导编码器学习对任务而言更有区分力的特征。我们在训练时加上一个根据编码器提取出来的特征预测部分可视token出现次数的分支,并与原模型一起进行多任务学习。
- 放宽的损失函数。由于键角进行了量化,在两个区间边界附近的键角对应的预测会不稳定,而轻微的键角误差并不会影响结果正误的判断(图同构即可),这会导致交叉熵纠结于这种无关重要的细节,使模型学不到更重要的东西。因此,在训练的末段,我们调整了损失函数以削弱这情况。
解码方式
DenseWAP作为一个图片到序列模型,它是通过自回归方式逐个token预测的。我们在解码阶段使用了以下技巧:
- 模型集成。通过从不同的初始化开始训练多个模型,在预测阶段可以对不同模型的输出进行平均,从而得到更可靠的置信度估计。这在模式识别相关的竞赛中应该算常规操作。
- 集束搜索。通过在解码过程中维护多个可能结果的前缀,可以利用后面token的信息纠正前面的错误,从而更充分地利用上下文信息。这是一种常规的解码策略。
- 形式语法约束。由于不是所有token的排列都是有意义,而模型从有限的样本往往不能学会输出字符串应该符合的规则,导致不时输出无效的输出。为了确保输出结果可以被下游应用(如评测工具、渲染工具和化学信息系统)顺利解析,我们手写了一个上下文无关语法去描述合法的输出字符串集合,并强制输出结果符合该语法。值得一提的是,LL(1)语法基于预测分析表的解析方法可以无缝集成到集束搜索的过程中。
化学结构式识别的未来
对于手写化学结构式这个任务,由于收集和标注高质量数据的成本高,缓解数据不足可能是进一步提高准确性的关键。一个方向是设法生成逼真的手写体化学结构式图片。由于印刷体图片容易生成,另一个方向是使用合适的损失函数来引导编码器对相同内容的手写体和印刷体图片提取相近的特征,再配合大量印刷体图片。当然,设计更合适的网络结构和损失函数也可能可以提高小样本泛化能力。比如说,基于目标检测的编码器和基于树/图的解码器也许可以更直接地建模化学结构式识别问题。
由于化学结构式识别不太可能做到高度准确,如何利用不准确的识别结果就是应用中必须面对的问题。一种可能思路是对置信度进行校准,把识别结果分成可靠与不可靠两类,前者高度准确以至可以用于自动化流程,而后者则需要人手检验并在需要时纠正。校对工具也许可以通过标出对识别结果把握较低的区域(包括输入和输出)来让人更容易发现错误,并提供不同粒度的候选来协助人纠正错误。虽然当前化学结构式识别的研究集中在图像方面,但基于轨迹的联机手写识别也有应用价值,而且交互式的输入方式更方便纠错。
和数学公式类似,目前把化学结构式输入计算机还有一些不方便的地方,而端侧识别模型可以提供一种自然的辅助输入方式。