ocr指定字符识别范围(两种方法)

使用的是paddleocr的代码

  • 代码下载地址:https://gitee.com/paddlepaddle/PaddleOCR/tree/release/2.6/deploy/cpp_infer
  • 使用的模型
    https://gitee.com/paddlepaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/models_list.md
    (原文此处为模型列表截图;本文使用其中的英文识别模型 en_PP-OCRv3_rec)
    下载地址:https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar
  • 使用的字典
    https://gitee.com/paddlepaddle/PaddleOCR/blob/release/2.6/ppocr/utils/en_dict.txt

原始版本的解码循环位于 ocr_rec.cpp 第 106 行起:

      // Upstream greedy CTC-style decoding of batch item m.
      // predict_batch is indexed as [(m * predict_shape[1] + n) * predict_shape[2] + c],
      // i.e. flattened [batch][time step n][class c].
      // For each time step: take the highest-scoring class, and append its label
      // unless it is index 0 (never emitted) or a repeat of the previous step's class.
      // score accumulates the top-1 scores of emitted steps; count is the number emitted.
      for (int n = 0; n < predict_shape[1]; n++) {
        // get idx: index of the highest-scoring class at this time step
        argmax_idx = int(Utility::argmax(
            &predict_batch[(m * predict_shape[1] + n) * predict_shape[2]],
            &predict_batch[(m * predict_shape[1] + n + 1) * predict_shape[2]]));
        // get score: the value of that top-1 class
        max_value = float(*std::max_element(
            &predict_batch[(m * predict_shape[1] + n) * predict_shape[2]],
            &predict_batch[(m * predict_shape[1] + n + 1) * predict_shape[2]]));
        // Skip index 0 and collapse consecutive repeats of the same class.
        if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
          score += max_value;
          count += 1;
          str_res += label_list_[argmax_idx];
        }         
         // Must run for every step (emitted or not) so repeats are detected.
         last_index = argmax_idx;
      }
  1. 方法一:取每一帧概率最大的字符,若它不在指定范围内则直接丢弃该帧:
      // Method 1: decode greedily as upstream does, but drop any time step whose
      // top-1 class is not in the selected character classes.
      //
      // Label indices assumed for en_dict.txt (label index = dictionary line
      // index + 1; index 0 is never emitted because of the `argmax_idx > 0` check):
      //   1..10                       digits 0-9
      //   18..43                      upper-case letters A-Z
      //   50..75                      lower-case letters a-z
      //   11..17, 44..49, 76..94      punctuation
      //   96                          space; kept in every filtered mode
      // NOTE(review): index 96 is presumably the space appended when the
      // dictionary is loaded -- confirm against how label_list_ is built.
      const bool Number = true;  // keep digits
      const bool Mark = true;    // keep punctuation
      const bool letter = true;  // keep letters (upper and lower case)
      for (int n = 0; n < predict_shape[1]; n++)
      {
        // get idx: index of the highest-scoring class at this time step
        argmax_idx = int(Utility::argmax(
            &predict_batch[(m * predict_shape[1] + n) * predict_shape[2]],
            &predict_batch[(m * predict_shape[1] + n + 1) * predict_shape[2]]));
        // get score: the value of that top-1 class
        max_value = float(*std::max_element(
            &predict_batch[(m * predict_shape[1] + n) * predict_shape[2]],
            &predict_batch[(m * predict_shape[1] + n + 1) * predict_shape[2]]));

        const int idx = argmax_idx;
        const bool is_digit = idx >= 1 && idx <= 10;
        const bool is_letter = (idx >= 18 && idx <= 43) || (idx >= 50 && idx <= 75);
        const bool is_mark = (idx >= 11 && idx <= 17) || (idx >= 44 && idx <= 49) ||
                             (idx >= 76 && idx <= 94);
        bool keep;
        if (Number && Mark && letter) {
          keep = true;   // nothing filtered: identical to the upstream decoder
        } else if (!Number && !Mark && !letter) {
          keep = false;  // nothing selected: emit nothing
        } else {
          keep = (Number && is_digit) || (Mark && is_mark) ||
                 (letter && is_letter) || idx == 96;
        }

        // Greedy CTC rule: skip index 0 and collapse repeats of the previous
        // step's top-1 class. last_index tracks the raw top-1 class even for
        // filtered steps, so the repeat check still sees every step.
        if (keep && argmax_idx > 0 && !(n > 0 && argmax_idx == last_index))
        {
          score += max_value;
          count += 1;
          str_res += label_list_[argmax_idx];
        }
        last_index = argmax_idx;
      }
  2. 方法二:
划定允许识别的字符范围:如果概率最大的字符不在范围内,则改取概率第二大的字符,以此类推,直到找到范围内的字符。注意 CTC 的空白符(索引 0)必须始终允许,否则空白帧会被替换成多余的字符。
          bool Number = true;//数字
          bool Mark = false;//标点
          bool letter = true;//字母
  for (int n = 0; n < predict_shape[1]; n++) {
    std::vector<float> datas(predict_shape[2]);
    memcpy(&datas[0], &predict_batch[n * predict_shape[2]], sizeof(float)*predict_shape[2]);
    std::vector<int>  idx_list = Argsort(datas);
    // std::cout<<"max: "<
    // std::cout<
    for(int j=idx_list.size()-1;;j--)
    {
  int idx=idx_list[j];
  // /*针对en_dict.txt这个字典,0~9数字,10~16,43~48,75~93符号,17~42大写字母,49~74小写字母,*/
          //标点数字字母全部勾选
          if (Number && Mark && letter)
          {
              argmax_idx = idx;
              break;
          }
          else if(!Number && !Mark && !letter)//标点数字字母全部不勾选
          {
            argmax_idx = 96;
            break;
          }
          else if(Number && !Mark && !letter)//只有数字
          {
            if(idx<11  || (idx==96))
                {
                  argmax_idx = idx;
                  break;                  
                }
             else{
                  continue;
             }             
          }
          else if(!Number && !Mark && letter)//只有字母
          {
            if((idx<76 && idx>49) ||  ((idx>17) && (idx<44))  || (idx==96))
                //if(argmax_idx==0 || (argmax_idx<76 && argmax_idx>49) ||  ((argmax_idx>17) && (argmax_idx<44))  || (argmax_idx==96))
                {
                  argmax_idx = idx;
                  break;                  
                }
            else{
                  continue;
                }           
          }
          else if(!Number && Mark && !letter)//只有标点
          {
              if((idx>=11 && idx<=17) ||  ((idx>=44) && (idx<=49))  || ((idx>=76) && (idx<=94))   || (idx==96) )
              // if(argmax_idx==0 || (argmax_idx>=11 && argmax_idx<=17) ||  ((argmax_idx>=44) && (argmax_idx<=49))  || ((argmax_idx>=76) && (argmax_idx<=94))   || (argmax_idx==96) )
                {
                  argmax_idx = idx;
                  break;                  
                }
             else{
                  continue;
                }            
          }
          else if(Number && Mark && !letter)//数字和标点
          {
            if(idx<=17 ||  ((idx>=44) && (idx<=49))  || ((idx>=76) && (idx<=94))   || (idx==96) )
                {
                  argmax_idx = idx;
                  break;                  
                }
             else{
                  continue;
                }        
          }
          else if(Number && !Mark && letter)//数字和字母
          {
                      // std::cout<<"--------------------"<
                if(idx<11 || (idx<76 && idx>49) ||  ((idx>17) && (idx<44))  || (idx==96))
                  {
                  argmax_idx = idx;
                  break;                  
                }
             else{
                  continue;
                  }           
          }
          else if(!Number && Mark && letter)//字母和标点
          {
            if((idx>=11 && idx<=94)  || (idx==96) )
            // if(argmax_idx==0 || (argmax_idx>=11 && argmax_idx<=94)  || (argmax_idx==96) )
                {
                  argmax_idx = idx;
                  break;                  
                }
             else{
                  continue;
                }
          }
        }

    
    // argmax_idx = int(Argmax(&predict_batch[n * predict_shape[2]],
    //                         &predict_batch[(n + 1) * predict_shape[2]]));
    max_value =
        float(*std::max_element(&predict_batch[n * predict_shape[2]],
                                &predict_batch[(n + 1) * predict_shape[2]]));
    // max_value = predict_batch[argmax_idx];
    if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
      score += max_value;
      count += 1;
      str_res += charactor_dict[argmax_idx];
    }
    last_index = argmax_idx;
  }

你可能感兴趣的:(ocr,c++,算法)