项目:基于深度学习的图像检索系统开发(2)

目录

一.项目背景:

二.项目内容:

2.1 项目步骤

2.1.1数据准备

2.1.2数据库设计

2.1.3特征提取与存储

2.1.4 图像检索实现

2.2遇到的问题及解决

2.2.1 问题1:特征向量维度不匹配

2.2.2 问题2:检索速度慢

2.3项目成果

三.项目心得:

3.1 技术收获

四.  项目代码实现:


一.项目背景:

在数字化浪潮席卷全球的当下,互联网数据量呈爆炸式增长,其承载的数据也正以指数级速度增长,其中视觉信息(图像、视频)的占比日益攀升,已成为信息传递和知识获取的核心载体和信息传播与知识获取的关键媒介。然而,依赖文本关键词的传统检索模式,在应对海量多元的图像数据时,暴露出诸多难以克服的缺陷:

描述瓶颈: 用户往往难以用精确、全面的文字描述复杂的视觉内容(如纹理、风格、特定物体组合、抽象概念)。

表达困境: 面对具有复杂光影、特殊构图或抽象意境的图像,用户难以用精准文字完整描述其视觉细节。​

语义分歧: 由于个体认知差异,不同用户对同一图像的文字表述大相径庭,导致检索结果偏离预期。​

标注匮乏: 大量图像缺乏标准的文本注释,使得基于文字的检索方式难以发挥作用。​

模态割裂: 无法支持用户通过实际图像样本直接检索相关内容,与自然检索需求存在显著脱节。​

主观差异: 不同用户对同一图像的描述可能存在显著差异,导致搜索结果偏差。

信息缺失: 大量图像缺乏有效、规范的文本标注(元数据),使得基于文本的检索无从下手。

跨模态鸿沟: 无法满足用户直接通过视觉信息(如手头图片、屏幕截图)来查找相关内容的自然需求。

这些局限性在多个行业产生了明显的负面效应:​

时尚零售: 消费者发现心仪服饰却不知如何搜索,难以快速找到同款或类似风格商品。​

新闻媒体与创意产业: 编辑需要快速定位相似图片素材,监测图片侵权行为,现有检索方式效率低下。​

医疗诊断: 影像科医生亟需参照同类病例影像资料辅助判断,传统检索难以满足时效性要求。​

文物保护与艺术研究: 鉴定人员需要快速比对文物图像特征,识别赝品或研究艺术风格演变。​

智能交通与公共安全: 利用图像识别技术追踪目标、分析异常行为,现有方法难以满足实时性需求。

数字资产管理/版权保护: 追踪图片在网络上的非法使用、管理大型图库资源需要高效的视觉匹配能力。

电子商务: 消费者看到心仪商品却不知其名,无法有效搜索;寻找风格、款式相似的商品困难重重。

社交媒体与内容平台: 用户希望根据图片发现相关内容、追踪图片传播路径、识别热点视觉趋势或检测重复/侵权内容。

为攻克这些难题,基于内容的图像检索(CBIR)技术逐渐兴起并快速迭代。该技术摒弃文本依赖,直接依据图像的色彩分布、纹理结构、几何形状及深层语义特征进行相似度匹配。随着深度学习技术的蓬勃发展,深度卷积神经网络(CNN)为 CBIR 注入了强劲动力。CNN 凭借强大的特征提取能力,能够从海量图像数据中自动学习高层语义特征,相较于传统人工设计的特征,在图像表征与匹配方面展现出显著优势。​

基于上述背景,本实训项目致力于开发一款深度学习驱动的图像检索系统原型。项目主要任务包括:​

算法落地: 借助预训练的深度 CNN 模型(如 ResNet)提取图像的深度语义特征,精准刻画图像核心信息。​

数据管理: 设计优化的数据存储架构,运用 MySQL 数据库实现大规模图像特征向量的高效存储与管理。​

功能开发: 实现用户上传图像后,系统即时提取特征并与数据库特征进行相似度计算(如欧氏距离),快速返回最相似的 Top-K 图像结果。​

系统验证: 通过搭建 Web 应用平台,测试系统在检索效率、准确率等指标上的性能表现,并在时尚商品检索、艺术作品溯源、医学影像辅助诊断等场景中验证其实际应用价值。​

应用验证:通过构建一个可操作的Web端应用,验证该方案在响应速度、检索准确率等方面的可行性,并展示其在多个潜在应用场景(如商品搜索、版权追踪初步分析、特定类别图像库检索)中的实用价值。

该系统原型整合了深度学习、数据库管理和 Web 开发等多领域技术,不仅是理论知识向实践转化的重要成果,更为解决实际信息检索难题提供了创新思路,兼具学术研究价值与广阔的应用潜力。

二.项目内容:

2.1 项目步骤

2.1.1数据准备

(1)收集猫、狗、花卉图片各300张,总计900张训练图像。

(2)按8:1:1比例划分训练集、验证集和测试集(720:90:90)。

2.1.2数据库设计

创建MySQL数据库 img_features_db,字符集 utf8mb4。

2.1.3特征提取与存储

(1)使用PyTorch加载预训练的ResNet34模型,移除全连接层,保留特征提取部分。

(2)提取所有图片的特征向量(512维),转换为JSON字符串后存入数据库。

2.1.4 图像检索实现

(1)用户上传图片后,提取其特征向量。

(2)计算与数据库中所有特征的余弦相似度,返回Top-K相似结果。

2.2遇到的问题及解决

2.2.1 问题1:特征向量维度不匹配

现象:插入数据库时报错Data too long。

解决:将512维浮点数组转为JSON字符串存储,而非直接拼接字符串。

2.2.2 问题2:检索速度慢

现象:900条数据查询耗时>3秒。

解决:添加索引优化查询,限制返回结果数(Top-10)。

2.3项目成果

(1)成功构建MySQL特征库,存储900张图片特征。

(2)实现Web端图像检索功能,平均响应时间<1秒。

(3)检索准确率:测试集Top-5相似度达92%。

三.项目心得:

3.1 技术收获

(1)掌握了MySQL数据库设计、连接和优化技巧。

(2)深入理解了特征向量存储与相似度计算的工程实现。

(3)实践了PyTorch模型部署与特征提取流程。

四.  项目代码实现:

import os
import torch
import numpy as np
import pymysql
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
from tqdm import tqdm
import faiss
import time
import matplotlib
# 设置matplotlib中文字体
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'SimSun', 'Arial Unicode MS']  
matplotlib.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 数据库配置
DB_CONFIG = {
    'host': 'rm-7xv8avfp2k7h48m4g7o.mysql.rds.aliyuncs.com',
    'port': 3306,
    'user': 'zhj_one',
    'password': 'Py20040627',
    'database': 'z_h_j_main1',
    'charset': 'utf8mb4'
}

# 图像预处理
def get_transform():
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# 特征提取模型
class FeatureExtractor:
    def __init__(self, model_path=None, num_classes=3):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        
        # 如果提供了模型路径,加载模型
        if model_path and os.path.exists(model_path):
            # 先修改模型的全连接层以匹配训练好的模型结构
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
            # 加载训练好的权重
            state_dict = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            print(f"加载训练好的模型:{model_path}")
        
        # 移除最后的全连接层,只用于特征提取
        self.model = nn.Sequential(*list(self.model.children())[:-1])
        self.model = self.model.to(self.device)
        self.model.eval()
        self.transform = get_transform()
        
    def extract_features(self, img_path):
        img = Image.open(img_path).convert('RGB')
        img_tensor = self.transform(img).unsqueeze(0).to(self.device)
        
        with torch.no_grad():
            features = self.model(img_tensor)
        
        # 将特征调整为一维向量
        features = features.view(features.size(0), -1).cpu().numpy()
        return features[0]  # 返回numpy数组
    
    def batch_extract_features(self, img_paths, batch_size=16):
        """批量提取特征,加速处理"""
        all_features = []
        with tqdm(total=len(img_paths), desc="提取特征") as pbar:
            for i in range(0, len(img_paths), batch_size):
                batch_paths = img_paths[i:i+batch_size]
                batch_tensors = []
                
                for img_path in batch_paths:
                    try:
                        img = Image.open(img_path).convert('RGB')
                        img_tensor = self.transform(img).unsqueeze(0)
                        batch_tensors.append(img_tensor)
                    except Exception as e:
                        print(f"处理图片失败 {img_path}: {e}")
                        batch_tensors.append(torch.zeros(1, 3, 224, 224))
                
                if batch_tensors:
                    batch_input = torch.cat(batch_tensors, 0).to(self.device)
                    with torch.no_grad():
                        batch_features = self.model(batch_input)
                        batch_features = batch_features.view(batch_features.size(0), -1).cpu().numpy()
                        all_features.extend(batch_features)
                
                pbar.update(len(batch_paths))
        
        return all_features

# 数据库管理类
class DatabaseManager:
    def __init__(self, config=DB_CONFIG):
        self.config = config
        self.conn = None
        self.cursor = None
    
    def connect(self):
        try:
            self.conn = pymysql.connect(**self.config)
            self.cursor = self.conn.cursor()
            print("数据库连接成功")
            return True
        except Exception as e:
            print(f"数据库连接失败: {e}")
            return False
    
    def close(self):
        if self.cursor:
            self.cursor.close()
        if self.conn:
            self.conn.close()
        print("数据库连接已关闭")
    
    def create_table(self):
        try:
            # 创建图像特征表
            sql = """
            CREATE TABLE IF NOT EXISTS image_features (
                id INT AUTO_INCREMENT PRIMARY KEY,
                image_path VARCHAR(255) NOT NULL,
                class_name VARCHAR(50) NOT NULL,
                features LONGBLOB NOT NULL,
                create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """
            self.cursor.execute(sql)
            self.conn.commit()
            print("表创建成功或已存在")
            return True
        except Exception as e:
            print(f"表创建失败: {e}")
            self.conn.rollback()
            return False
    
    def batch_insert_features(self, data_batch):
        """批量插入特征数据"""
        try:
            sql = "INSERT INTO image_features (image_path, class_name, features) VALUES (%s, %s, %s)"
            batch_data = []
            for item in data_batch:
                image_path, class_name, features = item
                binary_features = features.tobytes()
                batch_data.append((image_path, class_name, binary_features))
            
            self.cursor.executemany(sql, batch_data)
            self.conn.commit()
            return True
        except Exception as e:
            print(f"批量插入特征失败: {e}")
            self.conn.rollback()
            return False
    
    def insert_feature(self, image_path, class_name, features):
        try:
            # 将NumPy数组转换为二进制数据
            binary_features = features.tobytes()
            
            sql = "INSERT INTO image_features (image_path, class_name, features) VALUES (%s, %s, %s)"
            self.cursor.execute(sql, (image_path, class_name, binary_features))
            self.conn.commit()
            return True
        except Exception as e:
            print(f"插入特征失败: {e}")
            self.conn.rollback()
            return False
    
    def get_all_features(self):
        try:
            sql = "SELECT id, image_path, class_name, features FROM image_features"
            self.cursor.execute(sql)
            results = self.cursor.fetchall()
            
            feature_data = []
            image_paths = []
            class_names = []
            features_array = []
            
            print("正在加载数据库特征...")
            for row in tqdm(results, desc="加载特征"):
                id, image_path, class_name, binary_features = row
                # 将二进制数据转换回NumPy数组
                features = np.frombuffer(binary_features, dtype=np.float32)
                feature_data.append({
                    'id': id,
                    'image_path': image_path,
                    'class_name': class_name,
                    'features': features
                })
                image_paths.append(image_path)
                class_names.append(class_name)
                features_array.append(features)
            
            return feature_data, np.array(features_array), image_paths, class_names
        except Exception as e:
            print(f"获取特征失败: {e}")
            return [], np.array([]), [], []
    
    def clear_table(self):
        """清空表数据"""
        try:
            sql = "TRUNCATE TABLE image_features"
            self.cursor.execute(sql)
            self.conn.commit()
            print("表数据已清空")
            return True
        except Exception as e:
            print(f"清空表数据失败: {e}")
            self.conn.rollback()
            return False

# FAISS索引管理类
class FaissIndexManager:
    def __init__(self):
        self.index = None
        self.feature_dim = None
    
    def build_index(self, features):
        """构建FAISS索引"""
        if len(features) == 0:
            print("无特征数据,无法构建索引")
            return False
        
        self.feature_dim = features.shape[1]
        # 使用L2距离度量的欧几里德空间索引
        self.index = faiss.IndexFlatL2(self.feature_dim)
        # 归一化特征向量
        faiss.normalize_L2(features)
        # 添加到索引
        self.index = faiss.IndexIDMap(self.index)
        self.index.add_with_ids(features.astype(np.float32), np.array(range(len(features))))
        
        print(f"FAISS索引构建成功,特征维度: {self.feature_dim},特征数量: {len(features)}")
        return True
    
    def search(self, query_vector, top_k=5):
        """搜索最相似的特征"""
        if self.index is None:
            print("索引尚未构建")
            return None
        
        # 确保查询向量形状正确
        query_vector = query_vector.reshape(1, -1).astype(np.float32)
        # 归一化查询向量
        faiss.normalize_L2(query_vector)
        
        # 搜索
        distances, indices = self.index.search(query_vector, top_k)
        
        return distances[0], indices[0]

# 图像检索类
class ImageRetrieval:
    def __init__(self, model_path=None, num_classes=3):
        self.db_manager = DatabaseManager()
        self.feature_extractor = FeatureExtractor(model_path, num_classes)
        self.index_manager = FaissIndexManager()
        self.features_array = None
        self.image_paths = None
        self.class_names = None
        
    def setup(self):
        """设置数据库连接和表"""
        if self.db_manager.connect():
            return self.db_manager.create_table()
        return False
    
    def extract_and_store_features(self, dataset_path):
        """提取数据集中所有图片的特征并存储到数据库"""
        if not os.path.exists(dataset_path):
            print(f"数据集路径不存在: {dataset_path}")
            return False
        
        # 遍历数据集目录收集所有图片路径
        all_images = []
        
        # 遍历数据集目录
        classes = [d for d in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, d))]
        
        print("正在收集图片路径...")
        for class_name in classes:
            class_dir = os.path.join(dataset_path, class_name)
            print(f"处理类别: {class_name}")
            
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                if img_path.lower().endswith(('.png', '.jpg', '.jpeg')):
                    all_images.append((img_path, class_name))
        
        print(f"共收集了 {len(all_images)} 张图片")
        
        # 批量提取特征
        batch_size = 32  # 调整批量大小以适应内存
        total_batches = len(all_images) // batch_size + (1 if len(all_images) % batch_size > 0 else 0)
        
        for batch_idx in tqdm(range(total_batches), desc="批量处理"):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, len(all_images))
            batch_images = all_images[start_idx:end_idx]
            
            img_paths = [item[0] for item in batch_images]
            class_names = [item[1] for item in batch_images]
            
            # 提取特征
            features_list = self.feature_extractor.batch_extract_features(img_paths)
            
            # 准备批量插入数据
            data_batch = list(zip(img_paths, class_names, features_list))
            
            # 批量插入数据库
            self.db_manager.batch_insert_features(data_batch)
        
        print("所有图片特征提取并存储完成")
        return True
    
    def load_features_and_build_index(self):
        """从数据库加载特征并构建索引"""
        feature_data, self.features_array, self.image_paths, self.class_names = self.db_manager.get_all_features()
        
        if len(feature_data) == 0:
            print("数据库中没有图片特征数据")
            return False
        
        # 构建FAISS索引
        start_time = time.time()
        success = self.index_manager.build_index(self.features_array)
        end_time = time.time()
        
        if success:
            print(f"索引构建耗时: {end_time - start_time:.2f} 秒")
            return True
        return False
    
    def search_similar_images(self, query_image_path, top_k=5):
        """检索与查询图片最相似的K张图片"""
        if not os.path.exists(query_image_path):
            print(f"查询图片不存在: {query_image_path}")
            return None
        
        # 确保特征和索引已加载
        if self.features_array is None or self.index_manager.index is None:
            success = self.load_features_and_build_index()
            if not success:
                return None
        
        # 提取查询图片的特征
        query_features = self.feature_extractor.extract_features(query_image_path)
        
        # 使用FAISS索引搜索
        start_time = time.time()
        distances, indices = self.index_manager.search(query_features, top_k)
        end_time = time.time()
        
        print(f"检索耗时: {end_time - start_time:.4f} 秒")
        
        # 构造结果
        similar_images = []
        for i, idx in enumerate(indices):
            if idx < len(self.image_paths):
                similar_images.append({
                    'image_path': self.image_paths[idx],
                    'class_name': self.class_names[idx],
                    'similarity': 1.0 - distances[i]  # 转换为相似度
                })
        
        return similar_images
    
    def visualize_results(self, query_image_path, similar_images):
        """可视化检索结果"""
        if not similar_images:
            print("没有找到相似图片")
            return
        
        plt.figure(figsize=(15, 10))
        
        # 显示查询图片
        plt.subplot(3, 3, 1)
        query_img = Image.open(query_image_path).convert('RGB')
        plt.imshow(query_img)
        plt.title("查询图片")
        plt.axis('off')
        
        # 显示相似图片
        for i, img_info in enumerate(similar_images):
            plt.subplot(3, 3, i+2)
            img_path = img_info['image_path']
            img = Image.open(img_path).convert('RGB')
            plt.imshow(img)
            plt.title(f"{img_info['class_name']}\n相似度: {img_info['similarity']:.4f}", fontproperties='SimHei')
            plt.axis('off')
        
        plt.tight_layout()
        plt.savefig('retrieval_results.png', dpi=300, bbox_inches='tight')
        plt.show()
    
    def close(self):
        """关闭数据库连接"""
        self.db_manager.close()

# 主程序
def main():
    # 使用已训练的ResNet34模型
    model_path = "resnet34_output/resnet34_final.pth"
    num_classes = 3  # 猫、狗、花 三个类别
    
    # 初始化图像检索系统
    retrieval = ImageRetrieval(model_path, num_classes)
    
    # 设置数据库
    if not retrieval.setup():
        print("设置数据库失败")
        return
    
    # 自动执行三个步骤
    try:
        # 步骤1: 提取特征并存储到数据库
        print("\n===== 步骤1: 提取特征并存储到数据库 =====")
        # 检查数据库中是否已有足够的特征数据
        _, features_array, _, _ = retrieval.db_manager.get_all_features()
        
        # 如果数据库中已有数据,跳过特征提取步骤
        if len(features_array) > 0:
            print(f"数据库中已有 {len(features_array)} 条特征记录,跳过特征提取步骤")
        else:
            # 从训练集提取特征
            dataset_path = "train_resnet34_data/flower.v2i.folder/train"
            if not os.path.exists(dataset_path):
                print(f"数据集路径不存在: {dataset_path}")
                # 尝试其他可能的路径
                dataset_path = "train_resnet34_data/flower.v2i.folder/valid"
                if not os.path.exists(dataset_path):
                    print(f"备用数据集路径也不存在,请检查数据集位置")
                    return
            
            print(f"从 {dataset_path} 提取特征")
            retrieval.extract_and_store_features(dataset_path)
        
        # 步骤2: 从验证集选择一个随机图片进行检索
        print("\n===== 步骤2: 图像检索 =====")
        # 从验证集中随机选择一张图片作为查询图像
        valid_dir = "train_resnet34_data/flower.v2i.folder/valid"
        if not os.path.exists(valid_dir):
            print(f"验证集路径不存在: {valid_dir}")
            return
        
        # 随机选择一个类别
        classes = [d for d in os.listdir(valid_dir) if os.path.isdir(os.path.join(valid_dir, d))]
        if not classes:
            print("找不到类别目录")
            return
            
        random_class = classes[np.random.randint(0, len(classes))]
        class_dir = os.path.join(valid_dir, random_class)
        
        # 从该类别中随机选择一张图片
        images = [f for f in os.listdir(class_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        if not images:
            print(f"{random_class} 类别中没有图片")
            return
            
        random_image = images[np.random.randint(0, len(images))]
        query_image_path = os.path.join(class_dir, random_image)
        
        print(f"随机选择的查询图片: {query_image_path}")
        top_k = 5  # 返回相似度最高的5张图片
        
        # 执行检索
        similar_images = retrieval.search_similar_images(query_image_path, top_k)
        
        # 步骤3: 可视化展示结果
        print("\n===== 步骤3: 可视化检索结果 =====")
        if similar_images:
            retrieval.visualize_results(query_image_path, similar_images)
            print(f"检索结果已保存为 retrieval_results.png")
            
            # 输出检索结果详情
            print("\n检索结果详情:")
            for i, img in enumerate(similar_images):
                print(f"{i+1}. 类别: {img['class_name']}, 相似度: {img['similarity']:.4f}, 路径: {img['image_path']}")
        else:
            print("未找到相似图片")
    
    except Exception as e:
        print(f"执行过程中出错: {e}")
    finally:
        # 关闭数据库连接
        retrieval.close()
        print("\n处理完成!")

if __name__ == "__main__":
    main() 

你可能感兴趣的:(深度学习,人工智能,python,个人开发,mysql,前端,项目复盘)