链载Ai

标题: 【文档智能】轻量级级表格识别算法模型-SLANet [打印本页]

作者: 链载Ai    时间: 昨天 11:32
标题: 【文档智能】轻量级级表格识别算法模型-SLANet

前言

前面文档介绍了文档智能上多种思路及核心技术实现《【文档智能 & RAG】RAG增强之路:增强PDF解析并结构化技术路线方案及思路》,

表格识别作为文档智能的重要组成部分,面临着复杂结构和多样化格式的挑战。本文介绍的轻量级的表格识别算法模型——SLANet,旨在在保证准确率的同时提升推理速度,方便生产落地。SLANet综合了PP-LCNet作为基础网络,采用CSP-PAN进行特征融合,并引入Attention机制以实现结构与位置信息的精确解码。通过这一框架,SLANet不仅有效减少了计算资源的消耗,还增强了模型在实际应用场景中的适用性与灵活性。

PP-LCNet

PP-LCNet是一种一种轻量级的CPU卷积神经网络,在图像分类的任务上表现良好,具有很高的落地意义。PP-LCNet的准确度显著优于具有相同推理时间的先前网络结构。

模型细节

PP-LCNet系列效果

CSP-PAN

PAN结构图:相比于原始的FPN多了自下而上的特征金字塔。

CSPNet是一种处理的思想,可以和ResNet、ResNeXt和DenseNet结合。用 CSP 网络进行相邻 feature maps 之间的特征连接和融合。

CSP-PAN的引入主要有下面三个目的:

  1. 增强CNN的学习能力
  2. 减少计算量
  3. 降低内存占用

SLANet

原理:

从上图看,SLANet主要由PP-LCNet + CSP-PAN + Attention组合得到。

核心代码实现

importtorch
fromtorchimportnn
fromtorch.nnimportfunctionalasF


classSLAHead(nn.Module):
def__init__(self,in_channels=96,is_train=False)->None:
super().__init__()
self.max_text_length=500
self.hidden_size=256
self.loc_reg_num=4
self.out_channels=30
self.num_embeddings=self.out_channels
self.is_train=is_train

self.structure_attention_cell=AttentionGRUCell(in_channels,
self.hidden_size,
self.num_embeddings)

self.structure_generator=nn.Sequential(
nn.Linear(self.hidden_size,self.hidden_size),
nn.Linear(self.hidden_size,self.out_channels)
)

self.loc_generator=nn.Sequential(
nn.Linear(self.hidden_size,self.hidden_size),
nn.Linear(self.hidden_size,self.loc_reg_num)
)

defforward(self,fea):
batch_size=fea.shape[0]

#1x96x16x16→1x96x256
fea=torch.reshape(fea,[fea.shape[0],fea.shape[1],-1])

#1x256x96
fea=fea.permute(0,2,1)

#infer1x501x30
structure_preds=torch.zeros(batch_size,self.max_text_length+1,
self.num_embeddings)
#1x501x4
loc_preds=torch.zeros(batch_size,self.max_text_length+1,
self.loc_reg_num)

hidden=torch.zeros(batch_size,self.hidden_size)
pre_chars=torch.zeros(batch_size,dtype=torch.int64)

loc_step,structure_step=None,None
foriinrange(self.max_text_length+1):
hidden,structure_step,loc_step=self._decode(pre_chars,
fea,hidden)
pre_chars=structure_step.argmax(dim=1)
structure_preds[:,i,:]=structure_step
loc_preds[:,i,:]=loc_step

ifnotself.is_train:
structure_preds=F.softmax(structure_preds,dim=-1)
#structure_preds:1x501x30
#loc_preds:1x501x4
returnstructure_preds,loc_preds

def_decode(self,pre_chars,features,hidden):
emb_features=F.one_hot(pre_chars,num_classes=self.num_embeddings)
(output,hidden),alpha=self.structure_attention_cell(hidden,
features,
emb_features)
structure_step=self.structure_generator(output)
loc_step=self.loc_generator(output)
returnhidden,structure_step,loc_step


classAttentionGRUCell(nn.Module):
def__init__(self,input_size,hidden_size,num_embedding)->None:
super().__init__()

self.i2h=nn.Linear(input_size,hidden_size,bias=False)
self.h2h=nn.Linear(hidden_size,hidden_size)
self.score=nn.Linear(hidden_size,1,bias=False)

self.gru=nn.GRU(input_size=input_size+num_embedding,
hidden_size=hidden_size,)
self.hidden_size=hidden_size

defforward(self,prev_hidden,batch_H,char_onehots):
#这里实现参考论文https://arxiv.org/pdf/1704.03549.pdf
batch_H_proj=self.i2h(batch_H)
prev_hidden_proj=torch.unsqueeze(self.h2h(prev_hidden),dim=1)

res=torch.add(batch_H_proj,prev_hidden_proj)
res=F.tanh(res)
e=self.score(res)

alpha=F.softmax(e,dim=1)
alpha=alpha.permute(0,2,1)
context=torch.squeeze(torch.matmul(alpha,batch_H),dim=1)
concat_context=torch.concat([context,char_onehots],1)

cur_hidden=self.gru(concat_context,prev_hidden)
returncur_hidden,alpha


classSLALoss(nn.Module):
def__init__(self)->None:
super().__init__()
self.loss_func=nn.CrossEntropyLoss()
self.structure_weight=1.0
self.loc_weight=2.0
self.eps=1e-12

defforward(self,pred):
structure_probs=pred[0]
structure_probs=structure_probs.permute(0,2,1)
#1x30x501

#1x501
structure_target=torch.empty(1,501,dtype=torch.long).random_(30)
structure_loss=self.loss_func(structure_probs,structure_target)
structure_loss=structure_loss*self.structure_weight

loc_preds=pred[1]#1x501x4
loc_targets=torch.randn(1,501,4)
loc_target_mask=torch.randn(1,501,1)

loc_loss=F.smooth_l1_loss(loc_preds*loc_target_mask,
loc_targets*loc_target_mask,
reduction='mean')
loc_loss*=self.loc_weight
loc_loss=loc_loss/(loc_target_mask.sum()+self.eps)

total_loss=structure_loss+loc_loss
returntotal_loss






欢迎光临 链载Ai (https://www.lianzai.com/) Powered by Discuz! X3.5