用 PyTorch 训练手写签名模型并进行签名识别

整体思路

  • 收集至少两个人的手写签名图片，每人至少 20 张
  • 使用 PyTorch 训练模型
  • 使用 Flask 搭建 Web 服务
  • 使用 HTML/JavaScript 实现前端页面，调用后端接口进行签名识别

项目结构

  • signature-system
    • data
      • train
        • user001
          • 001.png
          • ...
        • user002
          • 001.png
          • ...
    • templates
      • index.html
    • app.py
    • model.py
    • train.py

建模：model.py

import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet

class SignatureModel(nn.Module):
    """Signature classifier built on an EfficientNet-B0 backbone.

    The backbone's final fully connected layer (``_fc``) is replaced with an
    identity, so the backbone's output features feed two separate heads:

    * ``classifier`` -- an MLP producing one logit per signer class;
    * ``anomaly`` -- an MLP producing a single score squashed into (0, 1).

    Args:
        num_classes: Number of signer classes output by the classifier head.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Single-channel (grayscale) input.
        # NOTE(review): with in_channels != 3, efficientnet_pytorch appears to
        # rebuild the stem convolution after loading the pretrained weights,
        # so that first layer is likely randomly initialised -- confirm
        # against the installed library version.
        self.base = EfficientNet.from_pretrained('efficientnet-b0', in_channels=1)
        in_features = self.base._fc.in_features
        # Drop the pretrained classification head so the backbone returns
        # features instead of ImageNet logits.
        self.base._fc = nn.Identity()

        # Main classification head: features -> per-class logits.
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

        # Anomaly head: features -> a single score in (0, 1).
        # How this score is trained or thresholded is not visible here.
        self.anomaly = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the backbone once and both heads on the shared features.

        Args:
            x: Input image batch for the backbone -- presumably grayscale
               tensors shaped ``(batch, 1, H, W)``; TODO confirm.

        Returns:
            dict with:
              ``'class'``     -- raw classifier logits (last dim = num_classes);
              ``'anomaly'``   -- anomaly score in (0, 1) (last dim = 1);
              ``'predicted'`` -- element-wise ``sigmoid`` of the same logits
                                 returned under ``'class'``.
        """
        features = self.base(x)
        # Compute the logits exactly once. The previous version called the
        # classifier twice: this doubled the head's cost, and in train mode
        # Dropout drew a different mask on each call, so 'predicted' did not
        # correspond to 'class'.
        logits = self.classifier(features)
        return {
            'class': logits,
            'anomaly': self.anomaly(features),
            # NOTE: sigmoid scores each class independently; for mutually
            # exclusive signers, softmax would give a proper distribution.
            # The argmax is identical either way.
            'predicted': torch.sigmoid(logits),
        }

训练模型

train.py
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2
import numpy as np
import os
from model import SignatureModel
import torch.nn as nn

class SignatureDataset(Dataset):
    def __init__(self, data_dir, img_size=224):
        self.img_size = img_size
        self.classes = sorted(os.listdir(data_dir))
        # print("self.classes:",self.classes)
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        # print("self.class_to_idx:",self.class_to_idx)
        
        self.image_paths = []
        self.labels = []
        
        for c in self.classes:
            class_dir = os.path.join(data_dir, c)
            for img_file in os.listdir(class_dir):
    

你可能感兴趣的:(#,AI体系学习,pytorch,人工智能,python)