本文介绍了一套基于PyTorch和PyQt5的枸杞与沙棘果实识别系统。该系统采用卷积神经网络模型,实现了90%以上的识别准确率,响应时间小于500ms,显著提升了传统人工分拣效率。系统具备以下特点:1)可视化交互界面,包含分类显示区、控制面板和参数调节功能;2)支持置信度阈值动态调整(50%-95%);3)提供单图/批量图像处理能力。文章详细解析了系统架构、核心模块代码及功能实现,同时指出了当前存在的缺陷(如缺乏早停机制、预处理简单等),并提出了引入主流模型架构、增强数据预处理等改进方向。该系统为农业智能化
在农业智能化转型背景下,果实精准识别成为提升采摘效率的关键技术。本系统基于PyTorch框架构建卷积神经网络模型,结合PyQt5开发可视化界面,实现枸杞与沙棘果实的智能分类识别,为采摘机器人提供核心视觉处理能力。
• 传统人工分拣效率低(约200kg/人/天)
• 两种果实形态相似导致误判率高(人工误判率约15%)
• 野外作业需实时处理能力(响应时间<500ms)
• 实现≥90%的识别准确率
• 开发易用图形界面支持现场部署
• 提供置信度阈值可调机制(50%-95%)
• 支持批量图像自动分类处理
class FruitApp(QWidget):
def __init__(self):
# 界面元素构成
self.medlar_group = QGroupBox('枸杞 ') # 分类结果显示区
self.sea_group = QGroupBox('沙棘 ')
# 核心功能按钮组
self.button_train = QPushButton("开始训练") # 模型训练入口
• 分类结果显示:三列滚动列表分别显示枸杞、沙棘及低置信度样本
• 控制面板:包含训练/测试/清空/退出功能按钮
• 参数调节区:滑动条实时控制分类阈值
• 图像预览:显示当前选中图片的缩略图(400x400px)
self.setWindowTitle('枸杞/沙棘果识别系统') # 设置窗口标题
self.setGeometry(100, 100, 1000, 800) # 设置窗口在屏幕中的位置与大小
# 设置窗口中所有 QPushButton 的统一样式(字体大小、圆角、背景颜色)
self.setStyleSheet("""
QPushButton {
font-size: 16px;
padding: 10px;
border-radius: 5px;
color: white;
}
QPushButton#button_train { background-color: #4CAF50; } # 绿色训练按钮
QPushButton#button_test_folder { background-color: #FFC107; color: black; } # 黄色多图测试按钮
QPushButton#button_test_image { background-color: #9C27B0; } # 紫色单图测试按钮
QPushButton#button_clear { background-color: #2196F3; } # 蓝色清除按钮
QPushButton#button_exit { background-color: #F44336; } # 红色退出按钮
""")
button_train
:绿色(#4CAF50)按钮,用于训练
button_test_folder
:黄色(#FFC107)按钮,测试多个图像,字体为黑色
button_test_image
:紫色(#9C27B0)按钮,测试单张图像
button_clear
:蓝色(#2196F3)按钮,用于清除
button_exit
:红色(#F44336)按钮,退出程序
setWindowTitle(...)
:设置窗口标题为“枸杞/沙棘果识别系统”。
setGeometry(x, y, width, height)
:设置窗口初始位置为 (100, 100)
,宽度为 1000
像素,高度为 800
像素。
self.threshold_slider = QSlider(Qt.Horizontal) # 创建水平方向滑块控件
self.threshold_slider.setMinimum(50) # 设置滑块最小值为 50(即0.50置信度)
self.threshold_slider.setMaximum(95) # 设置滑块最大值为 95(即0.95置信度)
self.threshold_slider.setValue(60) # 默认初始值为 60(即0.60置信度)
self.threshold_slider.valueChanged.connect(self.update_threshold_label) # 滑块值改变时触发函数
self.threshold_label = QLabel(f"当前阈值:{self.threshold_slider.value()/100:.2f}") # 显示阈值文字
QSlider(Qt.Horizontal)
:创建一个水平方向的滑块控件
setMinimum(50)
:设置最小值为 50,表示最低 0.50 置信度
setMaximum(95)
:设置最大值为 95,表示最高 0.95 置信度
setValue(60)
:初始值为 60,对应 0.60 置信度
valueChanged.connect(...)
:当用户移动滑块时,自动调用更新函数
QLabel(...)
:创建一个标签,实时显示当前阈值,保留两位小数
def show_image(self, item):
image_path = item.data(Qt.UserRole) # 获取图像路径(此前存储在 item 的“用户角色”字段中)
subprocess.Popen(['start', '', image_path], shell=True) # 在 Windows 系统中使用默认图片查看器打开该路径
item.data(Qt.UserRole)
:从列表项中取出之前通过 setData(Qt.UserRole, path)
存储的图像路径
subprocess.Popen(['start', '', image_path], shell=True)
:使用 Windows 默认程序打开该图像文件
start
是 Windows 命令
''
是为了占位(防止路径被解释为窗口标题)
shell=True
允许使用 shell 命令
def load_model(self):
if os.path.exists(MODEL_PATH): # 判断模型文件是否存在
try:
self.model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu'))) # 加载参数
self.model.eval() # 设置为评估模式
self.label.setText("状态:已加载模型") # 状态栏提示成功
except Exception as e:
self.label.setText(f"状态:加载模型失败 ({str(e)})") # 异常提示
else:
self.label.setText("状态:未找到模型") # 文件不存在时提示
os.path.exists(MODEL_PATH)
:判断模型文件是否存在
torch.load(..., map_location=torch.device('cpu'))
:将模型加载到 CPU 上(即使是 GPU 保存的)
self.model.load_state_dict(...)
:将保存的权重参数加载进模型结构中
self.model.eval()
:设置模型为评估模式(关闭 Dropout、BatchNorm 等训练特性)
self.label.setText(...)
:更新状态标签,向用户反馈当前加载状态
def process_image(self, image_path):
image = Image.open(image_path).convert('RGB') # 打开并转为 RGB 格式
image = self.transform(image).unsqueeze(0) # 数据预处理并增加 batch 维度
self.model.eval()
with torch.no_grad(): # 禁用梯度以提升推理速度
output = self.model(image) # 前向传播获取预测值
_, predicted = torch.max(output, 1) # 获取预测类别
confidence = torch.softmax(output, dim=1)[0, predicted].item() # 获取最大置信度
label = '枸杞' if predicted.item() == 0 else '沙棘' # 类别标签映射
if confidence < self.threshold: # 若置信度低于阈值则归为“未知”
label = '未知'
Image.open(...).convert('RGB')
:确保图像是 3 通道 RGB 格式
self.transform(image).unsqueeze(0)
:应用预处理,并添加 batch 维度(从 [C,H,W]
变为 [1,C,H,W]
)
self.model.eval()
:切换到推理模式,避免 BatchNorm/Dropout 引入不确定性
torch.no_grad()
:推理阶段禁用自动求导,节省显存和计算
torch.max(output, 1)
:从输出中获得最大概率的类别索引
torch.softmax(...)[0, predicted].item()
:计算对应预测类别的置信度
predicted.item() == 0
:将类别索引映射为中文标签(枸杞 or 沙棘)
if confidence < self.threshold
:如果置信度低于阈值,标记为“未知”
def train_model(model, train_loader, criterion, optimizer, num_epochs=1000, log_file=None):
model.train()
losses = []
for epoch in range(num_epochs):
running_loss = 0.0
for images, labels, _ in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
epoch_loss = running_loss / len(train_loader)
losses.append(epoch_loss)
print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.4f}")
if log_file:
log_file.write(f"Epoch {epoch + 1}, Loss: {epoch_loss:.4f}\n")
model.train()
:确保模型处于训练模式(启用 Dropout、BatchNorm 等)
optimizer.zero_grad()
:每次迭代前清空梯度
model(images)
:获取模型输出
criterion(outputs, labels)
:根据预测和真实标签计算损失
loss.backward()
:反向传播计算梯度
optimizer.step()
:更新模型参数
tqdm(...)
:添加进度条以观察训练进度
log_file.write(...)
:可选,将损失写入日志文件便于持久保存
plt.figure()
plt.plot(range(1, num_epochs + 1), losses) # 横轴为 epoch,纵轴为 loss
plt.title('Training Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid()
plt.savefig('training_loss.png') # 保存图像至当前目录
plt.show() # 显示图像
plt.figure()
:新建一个图像窗口
plt.plot(...)
:绘制训练损失曲线
plt.title(...)
/ xlabel(...)
/ ylabel(...)
:设置图像标题和坐标轴标签
plt.grid()
:让图像更易读
plt.savefig(...)
:保存为 PNG 文件,方便后续查看或报告使用
plt.show()
:在窗口中展示图像(如在脚本中运行)
当前模型训练过程未设置验证集,也未引入早停机制。在训练过程中可能出现过拟合现象,即模型在训练集上表现良好,但在真实应用中准确率下降。此外,若训练时间过长又无停止条件,容易浪费资源。
本项目仅对图像进行了缩放(Resize)和张量化(ToTensor)处理,未包含亮度调整、旋转、裁剪等图像增强操作。这会导致模型对不同拍摄条件的适应能力下降,影响泛化能力。
当前 CNN 模型为轻量级结构,未引入 Dropout、Batch Normalization 等防止过拟合的模块,也未进行特征提取层的深度设计,限制了其分类精度和稳定性。
无法对模型误识别图像进行标记和反馈。这样一来,系统无法自动收集“困难样本”来优化训练集,限制了模型的自我进化能力。
使用 PyQt5 实现的图形界面直观友好,按钮功能齐全,用户易于操作;
支持图像的单图测试与批量识别,兼容性与效率兼具;
模型训练过程提供日志记录与损失曲线可视化,增强开发者调试与追踪的能力;
提供置信度阈值滑块,允许用户灵活调节模型的识别“敏感度”。
增加验证集与早停机制:提升模型训练的科学性,防止过拟合;
引入更复杂的神经网络结构:如使用 ResNet、MobileNet 等主流模型,提升识别精度;
支持用户反馈再训练机制:允许用户对误判图像进行标记,并添加到新训练集中优化模型;
增强图像预处理流程:引入图像增强与数据扩增技术,提高模型的鲁棒性和适应性。
本项目是基于 PyTorch 和 PyQt5 的轻量级农业图像分类系统,适用于科研入门与实际场景落地,特别适合研究图像识别、界面设计及模型训练机制的开发者参考与使用。欢迎有志同道合者共同优化和完善,推动智慧农业的发展!