Advertisement

记录使用pytorch训练crnn

阅读量:

工程来源:

https://github.com/WenmuZhou/PytorchOCR?tab=readme-ov-file#train

在基本数据准备以及配置方面与paddleOCR具有相似性;对使用过程中遇到的问题进行做记录。

1.环境

我使用的是:172.31.50.201:5000/algorithm/pytorch-1.11.0-cuda11.3-cudnn8-devel-arcface:v2

然后陆续按照要求安装了库:

复制代码
 #pip install imgaug -i https://pypi.tuna.tsinghua.edu.cn/simple

    
 #pip install pyclipper -i https://pypi.tuna.tsinghua.edu.cn/simple
    
 #pip install lmdb -i https://pypi.tuna.tsinghua.edu.cn/simple
    
 #pip install rapidfuzz -i https://pypi.tuna.tsinghua.edu.cn/simple

2.训练时遇到的问题:训练一开始就NAN,使用小数据集时,acc一直为0:

解决办法是修改了CTCloss初始化:

在class CTCLoss(nn.Module)中

复制代码
    self.loss_func = nn.CTCLoss(blank=0, reduction='none',zero_infinity=True)

遇到问题时给的一些好的参考:

[[深度学习][pytorch][原创]crnn在高版本pytorch上训练loss为nan解决办法_crnn中train loss: nan-博客]( "深度学习pytorch原创crnn在高版本pytorch上训练loss为nan解决办法_crnn中train loss: nan-博客") [关于pytorch自带的CTCloss使用时的注意事项_pytorch ctc-博客]( "关于pytorch自带的CTCloss使用时的注意事项_pytorch ctc-博客")

https://zhuanlan.zhihu.com/p/67415439

然后就没有报错了

3.加载预训练模型代码修改

复制代码
 def load_pretrained_params(model, pretrained_model):

    
     # checkpoint = torch.load(pretrained_model, map_location=torch.device('cpu'))
    
     # model.load_state_dict(checkpoint['state_dict'], strict=False)
    
     backbone_dict = model.state_dict()
    
     pretrained_dict = torch.load(pretrained_model, map_location=torch.device('cpu'))
    
     pretrained_dict_backbone_ = {}
    
     for k, v in pretrained_dict['state_dict'].items():
    
     k_ = k.replace('module.', '')
    
     if k_ in backbone_dict and backbone_dict[k_].size() == v.size():
    
         pretrained_dict_backbone_[k_] = v
    
     else:
    
         print(k_, backbone_dict[k_].size(), v.size())
    
  
    
  
    
     backbone_dict.update(pretrained_dict_backbone_)
    
     model.load_state_dict(backbone_dict)

全部评论 (0)

还没有任何评论哟~