本章核心:通过实现一个分层架构的 SQLite 操作框架,深入学习 Python 数据库交互的最佳实践。我们将涵盖从数据库连接管理、模型定义、数据访问到业务逻辑的完整流程。
核心特性: 内置、无服务器、单文件数据库。遵循 DB-API 2.0 规范。
常用 API 参考表:
API | 描述 |
---|---|
sqlite3.connect(database, ...) |
连接数据库文件,返回 Connection 对象。 |
connection.cursor() |
创建 Cursor 对象。 |
cursor.execute(sql, parameters) |
执行单条 SQL (用 ? 占位符)。 |
cursor.executemany(sql, seq_of_params) |
批量执行 SQL。 |
connection.commit() |
提交事务。 |
connection.rollback() |
回滚事务。 |
cursor.fetchone() |
获取下一行结果 (元组或 Row 对象),无结果时 None 。 |
cursor.fetchall() |
获取所有剩余行结果 (元组或 Row 对象的列表)。 |
cursor.lastrowid |
(属性) 最后 INSERT 的行的 ROWID。 |
cursor.rowcount |
(属性) 最后 DML 操作影响的行数 (-1 表示不确定或不适用)。 |
connection.close() |
关闭连接。 |
connection.row_factory = sqlite3.Row |
(设置) 让查询结果可以通过列名访问(类似字典)。 |
sqlite3.Binary(bytes) |
用于封装要存入 BLOB 字段的二进制数据。 |
with sqlite3.connect(...) as conn: |
(推荐) 使用上下文管理器,自动处理连接关闭和基本事务。 |
框架实现步骤
我们创建以下目录结构,并在每个目录中放入 __init__.py
文件使其成为 Python 包:
sqlite_practice/
├── core/ # 核心功能模块
│ ├── __init__.py
│ ├── db_manager.py # 数据库连接管理
│ └── table_manager.py # 表结构管理
├── models/ # 数据模型
│ ├── __init__.py
│ └── task.py # 任务数据模型
├── repositories/ # 数据访问层
│ ├── __init__.py
│ └── task_repository.py # 任务数据访问
├── services/ # 业务逻辑层
│ ├── __init__.py
│ └── task_service.py # 任务业务逻辑
├── utils/ # 工具类
│ ├── __init__.py
│ └── date_utils.py # 日期处理工具(被删减了)
├── examples/ # 示例代码
│ ├── __init__.py
│ ├── basic_operations.py # 基本操作示例
│ ├── advanced_queries_oop.py # 高级查询示例
│
└── README.md # 项目说明文档
要点:
每个目录下的 __init__.py
文件(可以是空文件)是必需的,它告诉 Python 这个目录应该被视为一个包 (Package),从而允许我们使用点号 .
来导入其中的模块(例如 from core.db_manager import DatabaseManager
)。
这种结构实现了
关注点分离
models
只关心数据长什么样。repositories
只关心如何从数据库存取这些数据。services
只关心如何使用这些数据来完成业务目标。core
提供底层的数据库连接等基础服务。core/
)core/db_manager.py
)此类负责管理数据库连接,提供连接、关闭及上下文管理 (with
语句) 支持。
"""
任务数据模型模块,定义Task相关的数据结构。
# sqlite_practice/models/task.py
"""
from dataclasses import dataclass, field, asdict
from typing import Optional
from datetime import datetime
@dataclass
class Task:
"""任务数据模型类,使用dataclass简化代码"""
title: str # 任务标题(必填)
priority: int = 3 # 优先级(默认3)
is_completed: bool = False # 是否完成(默认False)
task_id: Optional[int] = None # 任务ID(可选,若不指定,则由数据库自动生成)
due_date: Optional[str] = None # 截止日期(可选
attachment: Optional[str] = None # 附件(可选)
created_at: Optional[str] = None # 创建时间(可选)
description: Optional[str] = None # 任务描述(可选)
last_updated: Optional[str] = None # 最后更新时间(可选)
def __post_init__(self):
"""dataclass提供的初始化后自动执行的函数,用于设置默认值"""
if self.created_at is None:
# 如果没有提供当前创建时间,则使用当前时间
self.created_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
def to_dict(self) -> dict:
"""将对象转换为字典,用于数据库操作
Returns:
dict: 包含任务数据的字典
"""
# 使用dataclasses.asdict获取基础字典
task_dict = asdict(self)
# SQLite中布尔值以0/1存储,需要转换
task_dict['is_completed'] = 1 if self.is_completed else 0
# 如果task_id为None,从字典中移除它
if self.task_id is None:
task_dict.pop('task_id')
return task_dict
@classmethod
def from_row(cls, row) -> 'Task':
"""从数据库行创建Task对象
这个方法的作用是将数据库查询结果(SQLite行数据)转换为Task对象。
看起来复杂是因为它需要处理多种可能的输入格式:
1. sqlite3.Row对象(有keys方法的字典类对象)
2. 元组或列表形式的结果
虽然dataclass简化了类定义,但不能自动处理从外部数据源(如数据库)
创建对象的过程,尤其是当数据需要类型转换时(如整数到布尔值)。
这种复杂性是为了提高代码的健壮性,确保从不同来源的数据都能正确转换为Task对象。
如果确定数据库始终返回同一格式的结果,可以简化此方法。
Args:
row: sqlite3.Row对象或类似字典/序列的对象
Returns:
Task: 创建的Task对象
"""
# 创建一个包含所有可能属性的字典
task_data = {}
# 检查row中是否有每个属性并添加到task_data
# 使用get方法避免KeyError
if hasattr(row, 'keys'):
# 如果row有keys方法(如sqlite3.Row),使用它
for key in row.keys():
task_data[key] = row[key]
else:
# 否则尝试按照索引获取
try:
task_data = {
'task_id': row[0] if len(row) > 0 else None,
'title': row[1] if len(row) > 1 else None,
'description': row[2] if len(row) > 2 else None,
'priority': row[3] if len(row) > 3 else 3,
'due_date': row[4] if len(row) > 4 else None,
'is_completed': bool(row[5]) if len(row) > 5 else False,
'attachment': row[6] if len(row) > 6 else None,
'created_at': row[7] if len(row) > 7 else None,
'last_updated': row[8] if len(row) > 8 else None
}
except (IndexError, TypeError):
# 如果索引访问失败,返回具有默认值的Task
return cls(title="Unknown")
# 处理布尔值转换
if 'is_completed' in task_data:
task_data['is_completed'] = bool(task_data['is_completed'])
# 创建并返回Task对象
return cls(**task_data)
core/table_manager.py
) `
"""
表管理模块,负责SQLite数据库表的创建和基础操作。
提供表结构定义、创建、修改等功能。
"""
import sqlite3
from typing import Dict, Any, List, Optional
class TableManager:
"""表管理类,负责SQLite数据库表的创建和操作"""
def __init__(self, conn: sqlite3.Connection, cursor: sqlite3.Cursor):
"""初始化表管理器
Args:
conn: 数据库连接对象
cursor: 数据库游标对象
"""
self.conn = conn
self.cursor = cursor
def create_table(self, table_name: str, columns: Dict[str, str]) -> bool:
"""创建数据库表
Args:
table_name: 表名
columns: 表列定义字典,格式为 {列名: 列类型}
Returns:
bool: 表创建是否成功
"""
try:
# 确保连接和游标有效
if not self.conn or not self.cursor:
print(f"[Error] 数据库连接或游标无效")
return False
# 构建列定义字符串
columns_def = []
for col_name, col_type in columns.items():
columns_def.append(f"{col_name} {col_type}")
create_table_sql = f'''CREATE TABLE IF NOT EXISTS {table_name} (
{', '.join(columns_def)}
)'''
self.cursor.execute(create_table_sql)
self.conn.commit()
print(f"[Setup] 表 '{table_name}' 已检查/创建。")
return True
except sqlite3.Error as e:
print(f"[Error] 创建表 '{table_name}' 时出错: {e}")
return False
def add_column(self, table_name: str, column_name: str, column_type: str) -> bool:
"""向已有表添加新列
Args:
table_name: 表名
column_name: 列名
column_type: 列类型和约束
Returns:
bool: 添加列是否成功
"""
try:
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
self.cursor.execute(alter_sql)
self.conn.commit()
print(f"[Setup] 已向表 '{table_name}' 添加列 '{column_name} {column_type}'")
return True
except sqlite3.Error as e:
print(f"[Error] 向表 '{table_name}' 添加列时出错: {e}")
return False
def table_exists(self, table_name: str) -> bool:
"""检查表是否存在
Args:
table_name: 表名
Returns:
bool: 表是否存在
"""
try:
self.cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
return self.cursor.fetchone() is not None
except sqlite3.Error as e:
print(f"[Error] 检查表 '{table_name}' 是否存在时出错: {e}")
return False
def get_table_info(self, table_name: str) -> List[Dict[str, Any]]:
"""获取表结构信息
Args:
table_name: 表名
Returns:
list: 包含列信息的字典列表
"""
try:
self.cursor.execute(f"PRAGMA table_info({table_name})")
columns = []
for row in self.cursor.fetchall():
column_info = {
'cid': row['cid'],
'name': row['name'],
'type': row['type'],
'notnull': row['notnull'],
'default_value': row['dflt_value'],
'pk': row['pk']
}
columns.append(column_info)
return columns
except sqlite3.Error as e:
print(f"[Error] 获取表 '{table_name}' 信息时出错: {e}")
return []
models/
)@dataclass
简化 Python 类定义在定义 Task
模型之前,我们先了解一下 Python 3.7+ 引入的一个非常有用的工具:@dataclass
装饰器。它能极大地简化用于存储数据的类的编写。
dataclass
是从 Python 3.7 版本开始,作为标准库dataclasses
中的模块被引入的。 随着 Python 版本的不断更新,dataclass
也逐步发展和完善,为 Python 开发者提供了更加便捷的数据类创建和管理方式。
dataclass
的主要功能在于帮助我们简化数据类的定义过程。 本文总结了几个我们在此框架中会用到的dataclass
技巧。
1. 传统的类定义方式
回顾一下,如果不用 dataclass
,定义一个简单的 CoinTrans
类(包含交易 ID, 交易对, 价格, 是否成功, 地址列表)大致如下:
# 传统方式定义类
class CoinTransTraditional:
def __init__(
self,
id: str,
symbol: str,
price: float,
is_success: bool,
addrs: list,
) -> None:
self.id = id
self.symbol = symbol
self.price = price
self.is_success = is_success
self.addrs = addrs
# 为了方便打印,需要自己实现 __str__ 或 __repr__
def __repr__(self) -> str:
# repr 通常返回更详细、无歧义的表示
return (f"CoinTransTraditional(id='{self.id}', symbol='{self.symbol}', "
f"price={self.price}, is_success={self.is_success}, addrs={self.addrs})")
# 使用传统类
if __name__ == "__main__":
coin_trans_trad = CoinTransTraditional("id01", "BTC/USDT", 71000.0, True, ["0x1111", "0x2222"])
# 需要实现 __repr__ 或 __str__ 才能得到有意义的打印输出
print(coin_trans_trad)
如你所见,我们需要编写 __init__
方法来初始化所有属性,还需要编写 __repr__
(或 __str__
) 方法才能方便地打印对象内容。
2. 使用 @dataclass
装饰器定义类
现在看看用 @dataclass
有多简单:
from dataclasses import dataclass, field # 导入 dataclass 和 field
from typing import List # 使用 List 进行类型提示
# 使用 @dataclass 装饰器
@dataclass
class CoinTransDataclass:
# 只需声明属性及其类型提示
id: str
symbol: str
price: float
is_success: bool
addrs: List[str] # 使用 List[str] 提供更精确的类型提示
# 使用 dataclass
if __name__ == "__main__":
coin_trans_dc = CoinTransDataclass("id01", "BTC/USDT", 71000.0, True, ["0x1111", "0x2222"])
# dataclass 自动生成了 __init__ 和 __repr__ 方法!
print(coin_trans_dc)
运行结果会直接打印出易于阅读的对象表示:
CoinTransDataclass(id='id01', symbol='BTC/USDT', price=71000.0, is_success=True, addrs=['0x1111', '0x2222'])
关键优势:
@dataclass
会自动为你生成 __init__
、__repr__
、__eq__
(等值比较) 等常用方法,大大减少样板代码。2.1 默认值与 default_factory
设置默认值很简单,直接在属性后面赋值即可。但对于可变类型(如 list
, dict
)作为默认值,直接赋值会引发 ValueError
,因为所有实例会共享同一个可变对象,这通常不是我们想要的。需要使用 field
函数和 default_factory
来指定一个工厂函数(一个无参可调用对象,返回默认值)。
from dataclasses import dataclass, field
from typing import List
# 工厂函数,用于生成默认的列表
def default_addr_list() -> List[str]:
print(" (调用了 default_addr_list 工厂函数)") # 演示工厂函数何时被调用
return ["0xdefault1", "0xdefault2"]
@dataclass
class CoinTransWithDefaults:
id: str = "default_id" # 字符串默认值
symbol: str = "BTC/USDT" # 字符串默认值
price: float = 0.0 # 数值默认值
is_success: bool = False # 布尔默认值
# 对于 list (可变类型),使用 default_factory 指定生成函数
addrs: List[str] = field(default_factory=default_addr_list)
# 另一个例子: 默认空列表
# related_ids: List[str] = field(default_factory=list)
# 使用带默认值的 dataclass
if __name__ == "__main__":
print("\n--- Dataclass 默认值示例 ---")
# 不提供参数,将使用默认值
default_trans = CoinTransWithDefaults()
print("创建第一个实例:")
print(default_trans)
# 再创建一个实例,验证列表是独立的
print("创建第二个实例:")
another_default = CoinTransWithDefaults()
print(another_default)
# 修改第一个实例的列表,不应影响第二个
default_trans.addrs.append("0xadded")
print("修改第一个实例的 addrs 后:")
print(f" 实例1: {default_trans}")
print(f" 实例2: {another_default}") # 实例2的列表未受影响
2.2 隐藏敏感信息 (repr=False
)
如果你不希望某些字段出现在 print()
(即 __repr__
) 的输出中(例如密码、密钥等),可以在 field
中设置 repr=False
。
@dataclass
class SensitiveData:
username: str
session_id: str
# 不希望 token 出现在 repr 输出中
token: str = field(repr=False)
# 也可以给 token 设置默认值或 default_factory
# token: str = field(default="dummy_token", repr=False)
if __name__ == "__main__":
print("\n--- Dataclass 隐藏字段示例 ---")
data = SensitiveData("user1", "sess123", token="secret_token_value")
print(data) # 输出中不会包含 token
# 但仍然可以访问该属性
print(f" 访问隐藏的 token: {data.token}")
2.3 只读对象 (frozen=True
)
如果你希望创建的对象是不可变的(创建后其属性值不能被修改),可以在 @dataclass
装饰器中设置 frozen=True
。这对于表示常量数据或确保数据不被意外篡改很有用。
@dataclass(frozen=True) # 设置为只读
class ImmutablePoint:
x: float
y: float
if __name__ == "__main__":
print("\n--- Dataclass 只读对象示例 ---")
p = ImmutablePoint(1.0, 2.0)
print(p)
# 尝试修改属性会抛出 FrozenInstanceError
try:
p.x = 5.0
except Exception as e:
print(f" 尝试修改只读对象失败: {type(e).__name__}: {e}")
2.4 转换为元组和字典 (astuple
, asdict
)
dataclasses
模块提供了 astuple
和 asdict
函数,可以方便地将 dataclass 实例转换为元组或字典,这在与其他库或模块交互时非常方便。
from dataclasses import astuple, asdict, dataclass
@dataclass
class SimpleConfig:
host: str = "localhost"
port: int = 8080
debug_mode: bool = False
if __name__ == "__main__":
print("\n--- Dataclass 转换示例 ---")
config = SimpleConfig(port=9000)
# 转换为元组
config_tuple = astuple(config)
print(f" 转换为元组: {config_tuple}")
# 转换为字典
config_dict = asdict(config)
print(f" 转换为字典: {config_dict}")
总结 (@dataclass
)
在 Python 中,数据类主要用于存储数据。 定义数据类时,通常需要编写一些重复性的代码,如构造函数 (
__init__
)、字符串表示 (__repr__
,__str__
) 等。@dataclass
装饰器的出现,使得这些通用方法的生成变得自动化,从而极大地简化了数据类的定义过程,提高了开发效率,是我们在数据建模时非常有用的工具。
现在,我们将使用 @dataclass
来定义我们的 Task
模型。
models/task.py
- 使用 @dataclass)"""
任务数据模型模块,定义Task相关的数据结构。
# sqlite_practice/models/task.py
"""
from dataclasses import dataclass, field
from typing import Optional
from datetime import datetime
@dataclass
class Task:
"""任务数据模型类,使用dataclass简化代码"""
title: str # 任务标题,必填
description: Optional[str] = None # 任务描述,可选
priority: int = 3 # 优先级,默认为3
due_date: Optional[str] = None # 截止日期,可选
is_completed: bool = False # 是否完成,默认为False
attachment: Optional[bytes] = None # 附件,可选
created_at: Optional[str] = None # 创建时间,可选
last_updated: Optional[str] = None # 最后更新时间,可选
task_id: Optional[int] = None # 任务ID,由数据库生成
def __post_init__(self):
"""对象初始化后运行的方法,用于设置默认值"""
if self.created_at is None:
# 如果没有提供创建时间,设置为当前时间
self.created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def to_dict(self) -> dict:
"""将对象转换为字典,用于数据库操作
Returns:
dict: 包含任务数据的字典
"""
task_dict = {
'title': self.title,
'description': self.description,
'priority': self.priority,
'due_date': self.due_date,
'is_completed': 1 if self.is_completed else 0, # SQLite中布尔值以0/1存储
'attachment': self.attachment,
'created_at': self.created_at,
'last_updated': self.last_updated
}
# 只有在task_id存在时才添加到字典中
if self.task_id is not None:
task_dict['task_id'] = self.task_id
return task_dict
@classmethod
def from_row(cls, row) -> 'Task':
"""从数据库行创建Task对象
Args:
row: sqlite3.Row对象或类似字典的对象
Returns:
Task: 创建的Task对象
"""
# 创建一个包含所有可能属性的字典
task_data = {}
# 检查row中是否有每个属性并添加到task_data
# 使用get方法避免KeyError
if hasattr(row, 'keys'):
# 如果row有keys方法(如sqlite3.Row),使用它
for key in row.keys():
task_data[key] = row[key]
else:
# 否则尝试按照索引获取
try:
task_data = {
'task_id': row[0] if len(row) > 0 else None,
'title': row[1] if len(row) > 1 else None,
'description': row[2] if len(row) > 2 else None,
'priority': row[3] if len(row) > 3 else 3,
'due_date': row[4] if len(row) > 4 else None,
'is_completed': bool(row[5]) if len(row) > 5 else False,
'attachment': row[6] if len(row) > 6 else None,
'created_at': row[7] if len(row) > 7 else None,
'last_updated': row[8] if len(row) > 8 else None
}
except (IndexError, TypeError):
# 如果索引访问失败,返回具有默认值的Task
return cls(title="Unknown")
# 处理布尔值转换
if 'is_completed' in task_data:
task_data['is_completed'] = bool(task_data['is_completed'])
# 创建并返回Task对象
return cls(**task_data)
知识点解释 (@dataclass
版 Task
):
@dataclass
装饰器: 应用在类定义之前,自动添加 __init__
, __repr__
, __eq__
等方法。title: str
, priority: int = 3
)。Optional[T]
表示该属性可以是 T
类型或 None
。priority: int = 3
)。field()
函数 (可选): 用于更精细地控制字段行为,如设置 default_factory
(用于可变默认值,本例未使用但已在知识点中介绍), repr=False
(不在打印输出中显示), init=False
(不由 __init__
初始化), compare=False
(不在等值比较中考虑) 等。to_dict()
方法: 方便将对象状态转换为字典,特别用于准备插入或更新到数据库的数据。注意这里手动处理了 is_completed
从 bool
到 int
的转换,以匹配 SQLite 的存储方式。@classmethod from_row(cls, row)
: 这是一个类方法(第一个参数是类本身 cls
,而不是实例 self
),专门用于从数据库返回的行数据(这里是 sqlite3.Row
对象)创建 Task
类的实例。这是将数据库数据映射回对象的常用模式。它负责处理列名到属性的映射和必要的数据类型转换(如 bool(row['is_completed'])
)。增加了错误处理,确保输入有效。repositories/
)repositories/task_repository.py
)此类封装所有与 tasks
表相关的 SQL 操作,并使用 Task
dataclass 进行数据交互。
"""
任务数据访问层模块,实现对Task数据的CRUD操作。
"""
import sqlite3
from typing import List, Optional, Dict, Any, Tuple
from models.task import Task
from core.table_manager import TableManager
class TaskRepository:
"""任务数据访问类,实现对Task表的增删改查操作"""
def __init__(self, conn: sqlite3.Connection, cursor: sqlite3.Cursor):
self.conn = conn
self.cursor = cursor
self.table_name = 'tasks' # 表名
self.table_manager = TableManager(self.conn, self.cursor)
def create_table(self) -> bool:
"""创建任务表"""
columns = {
"task_id": "INTEGER PRIMARY KEY AUTOINCREMENT", # 任务ID
"title": "TEXT NOT NULL", # 任务标题
"description": "TEXT", # 任务描述
"priority": "INTEGER DEFAULT 3", # 优先级,默认3
"due_date": "DATE", # 截止日期
"is_completed": "BOOLEAN DEFAULT 0", # 是否完成,默认False
"attachment": "BLOB", # 附件,用于存储文件
"created_at": "TIMESTAMP DEFAULT CURRENT_TIMESTAMP", # 创建时间
"last_updated": "TIMESTAMP" # 最后更新时间
}
# 调用TableManager的create_table方法创建任务表
created_flag = self.table_manager.create_table(self.table_name, columns)
return created_flag
def add_column(self, column_name: str, column_type: str) -> bool:
"""向任务表添加新列
Args:
column_name: 新列名
column_type: 新列类型
"""
added_flag = self.table_manager.add_column(self.table_name, column_name, column_type)
return added_flag
def get_table_info(self) -> List[Dict[str, Any]]:
"""获取任务表结构信息
Returns:
list: 包含列信息的字典列表
"""
table_info = self.table_manager.get_table_info(self.table_name)
return table_info
def insert_task(self, task: Task) -> Optional[int]:
"""插入新任务
Args:
task: Task对象
Returns:
int: 新任务的ID,如果插入失败返回None
"""
try:
# 将Task对象转换为字典
task_dict = task.to_dict()
# 如果存在task_id,则从字典中移除,因为它是自增的
if 'task_id' in task_dict:
del task_dict['task_id']
# 构建INSERT语句
columns = ', '.join(task_dict.keys())
"""
placeholders的执行逻辑:类似于下方的代码
task_dict = {'task1': 'value1', 'task2': 'value2', 'task3': 'value3'}
question_marks = ', '.join(['?' for _ in task_dict])
print(question_marks) # 输出: ?, ?, ?
'"""
placeholders = placeholders = ', '.join(['?' for _ in task_dict])
values = tuple(task_dict.values())
insert_sql = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
# 执行插入操作
self.cursor.execute(insert_sql, values)
self.conn.commit()
# 获取新插入记录的ID
inserted_id = self.cursor.lastrowid
print(f"[Info] 成功插入任务: '{task.title}', ID: {inserted_id}")
return inserted_id
except sqlite3.Error as e:
print(f"[Error] 任务插入失败: {e}")
self.conn.rollback()
return None
def insert_many(self, tasks: List[Task]) -> Optional[int]:
"""批量插入任务记录
Args:
tasks: 要插入的Task对象列表
Returns:
int: 成功插入的记录数,如果插入失败则返回None
"""
try:
# 确保有任务要插入
if not tasks:
print("[Warning] 没有任务要插入")
return 0
# 将所有Task对象转换为字典
task_dicts = [task.to_dict() for task in tasks]
# 移除所有task_id,因为它们是自增的
for task_dict in task_dicts:
if 'task_id' in task_dict:
del task_dict['task_id']
# 确保所有任务的字段相同
columns = list(task_dicts[0].keys())
placeholders = ', '.join(['?' for _ in columns])
# 构建INSERT语句
insert_sql = f"INSERT INTO {self.table_name} ({', '.join(columns)}) VALUES ({placeholders})"
# 准备批量插入的值
# 这段代码使用列表推导式创建一个包含所有任务数据的列表:
# 1. 外层循环遍历每个任务字典(task_dicts中的每个task_dict)
# 2. 内层循环按照columns中定义的列顺序提取每个任务字典中的值
# 3. 使用tuple()将每个任务的值转换为元组,确保与executemany()方法兼容
# 4. 最终生成的values是一个元组列表,每个元组包含一个任务的所有字段值
# 例如: [(任务1的值1, 任务1的值2...), (任务2的值1, 任务2的值2...), ...]
values = [tuple(task_dict[col] for col in columns) for task_dict in task_dicts]
print(f"[Info] 准备批量插入{len(values)}条任务")
# 执行批量插入
self.cursor.executemany(insert_sql, values)
self.conn.commit()
# 获取新插入记录的数量
inserted_count = self.cursor.rowcount
print(f"[Info] 成功插入{inserted_count}条任务")
return inserted_count
except sqlite3.Error as e:
print(f"[Error] 任务插入失败: {e}")
self.conn.rollback()
return None
def update_task(self, task: Task) -> bool:
"""更新任务记录
Args:
task: 要更新的Task对象,必须包含task_id
Returns:
bool: 是否更新成功
"""
try:
# 检查task_id是否存在
if task.task_id is None:
print("[Error] 无法更新任务:缺少task_id")
return False
# 将Task对象转换为字典
task_dict = task.to_dict()
# 从字典中移除task_id,他讲在WHERE子句中使用
task_id = task_dict.pop('task_id')
# 构建SET字句
"""
这将生成类似于下面的SQL语句:
UPDATE tasks SET title = 'task1', description = 'task1 description', priority = 1, due_date = '2022-01-01', is_completed = 1, attachment = b'123456', last_updated = '2022-01-01 00:00:00' WHERE task_id = 1
"""
set_clause = ', '.join([f"{col} = ?" for col in task_dict.keys()])
# 检查task_dict中是否已包含last_updated字段
if 'last_updated' not in task_dict:
# 如果没有包含last_updated字段,则使用SQLite的CURRENT_TIMESTAMP函数
# 这样可以确保last_updated字段使用数据库服务器的当前时间
set_clause += ", last_updated = CURRENT_TIMESTAMP"
# 构建UPDATE语句
update_sql = f"UPDATE {self.table_name} SET {set_clause} WHERE task_id = ?"
# 准备参数值
values = list(task_dict.values())
values.append(task_id)
# 执行更新操作
self.cursor.execute(update_sql, values)
self.conn.commit()
# 检查是否有行被更新
updated_count = self.cursor.rowcount
if updated_count > 0:
print(f"[Info] 成功更新任务 ID={task_id}")
return True
else:
print(f"[Warning] 未找到要更新的任务 ID={task_id}")
return False
except sqlite3.Error as e:
print(f"[Error] 任务更新失败: {e}")
self.conn.rollback()
return False
def delete_task(self, task_id: int) -> bool:
"""删除任务记录
Args:
task_id: 要删除的任务ID
Returns:
bool: 是否删除成功
"""
try:
# 构建DELETE语句
delete_sql = f"DELETE FROM {self.table_name} WHERE task_id = ?"
# 执行删除操作
self.cursor.execute(delete_sql, (task_id,))
# 检查是否有行被删除
deleted_count = self.cursor.rowcount
if deleted_count > 0:
print(f"[Info] 成功删除任务 ID={task_id}")
return True
else:
print(f"[Warning] 未找到要删除的任务 ID={task_id}")
return False
except sqlite3.Error as e:
print(f"[Error] 任务删除失败: {e}")
self.conn.rollback()
return False
def find_task_by_id(self, task_id: int) -> Optional[Task]:
"""根据任务ID查找任务记录"""
try:
# 构建SELECT语句
select_sql = f"SELECT * FROM {self.table_name} WHERE task_id = ?"
# 执行查询操作
self.cursor.execute(select_sql, (task_id,))
# 获取查询结果
row = self.cursor.fetchone()
# 如果未找到,返回None
if row is None:
return None
# 将结果转换为Task对象
return Task.from_row(row)
except sqlite3.Error as e:
print(f"[Error] 任务查找失败: {e}")
return None
def find_all_tasks(self) -> List[Task]:
"""查找所有任务记录"""
try:
# 构建SELECT语句
select_sql = f"SELECT * FROM {self.table_name}"
# 执行查询操作
self.cursor.execute(select_sql)
# 获取所有结果
rows = self.cursor.fetchall()
# 将结果转换为Task对象列表
return [Task.from_row(row) for row in rows]
except sqlite3.Error as e:
print(f"[Error] 任务查找失败: {e}")
return []
def find_by_criteria(self, criteria: Dict[str, Any]) -> List[Task]:
"""根据条件查找任务
Args:
criteria: 查询条件字典,格式为 {列名: 值}
Returns:
list: 符合条件的Task对象列表
"""
try:
# 确保有条件要查询
if not criteria:
return self.find_all_tasks()
# 构建WHERE子句
values = [] # 用于占位符
where_clauses = [] # 用于条件
for col, val in criteria.items():
where_clauses.append(f"{col} = ?")
values.append(val)
where_clause = ' AND '.join(where_clauses) # 例如:查询条件: priority = ? 若有多个条件,则用AND连接
# 构建SELECT语句
select_sql = f"SELECT * FROM {self.table_name} WHERE {where_clause}"
# 执行查询操作
self.cursor.execute(select_sql, values)
# 获取所有结果
rows = self.cursor.fetchall()
# 奖结果转换为Task对象列表
return [Task.from_row(row) for row in rows]
except sqlite3.Error as e:
print(f"[Error] 任务查找失败: {e}")
return []
def find_by_title_contains(self, title_part: str) -> List[Task]:
"""根据标题部分查找任务"""
try:
# 构建SELECT语句 使用LIKE模糊查询
select_sql = f"SELECT * FROM {self.table_name} WHERE title LIKE ?"
# 执行查询,使用%作为通配符
self.cursor.execute(select_sql, (f'%{title_part}%',))
# 获取所有结果
rows = self.cursor.fetchall()
# 将结果转换为Task对象列表
return [Task.from_row(row) for row in rows]
except sqlite3.Error as e:
print(f"[Error] 任务查找失败: {e}")
return []
知识点解释 (Repository - 适配 Dataclass):
insert
和 update
方法现在接收 Task
dataclass 实例。它们使用 task.to_dict()
来获取需要持久化的数据,而不是手动从对象属性中提取。from_row
的使用: find_all
, find_by_id
, find_by_criteria
等查询方法现在依赖 Task.from_row(row)
这个类方法来将数据库行数据(sqlite3.Row
对象)转换回 Task
dataclass 实例。这使得模型转换逻辑集中在模型类自身。insert
/update
时将 Python bool
转换为数据库的 INTEGER
(0/1),在 from_row
时将数据库的 INTEGER
转换回 bool
。insert
和 update
方法中动态构建了 SQL 语句的列名和占位符部分。这样做可以使代码适应模型的变化,但必须极其小心,确保列名不是来自用户输入,以防范 SQL 注入。这里列名来自 task.to_dict().keys()
,是安全的。sqlite3.Error
和 sqlite3.IntegrityError
的捕获以及事务回滚。查询方法在出错时返回空列表 []
或 None
。create_table
职责: 将创建表的逻辑放在 Repository 初始化时执行,确保操作 Repository 前表结构一定存在,这是一种常见的实践模式。services/
)services/task_service.py
)服务层的代码通常与底层数据模型是 dataclass 还是普通类关系不大,因为它主要通过 Repository 提供的接口(接收和返回 Task
对象)来工作。因此,之前的 TaskService
代码基本可以直接使用,或者做少量调整以利用 dataclass 的特性(如果需要的话)。
"""
任务服务模块,提供业务逻辑层功能。
负责处理任务相关的业务规则和操作。
# sqlite_practice/services/task_service.py
"""
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta
from models.task import Task
from repositories.task_repository import TaskRepository
class TaskService:
"""任务服务类,处理任务相关的业务逻辑"""
def __init__(self, task_repository: TaskRepository):
"""初始化任务服务
Args:
task_repository: 任务数据访问对象
"""
self.task_repository = task_repository
def create_task(self, title: str, description: str = None, priority: int = 3,
due_date: str = None, is_completed: bool = False,
attachment: bytes = None) -> Optional[int]:
"""创建新任务
Args:
title: 任务标题
description: 任务描述
priority: 优先级 (1-5),默认为3
due_date: 截止日期,格式为'YYYY-MM-DD'
is_completed: 是否已完成
attachment: 附件数据
Returns:
int: 新创建任务的ID,失败则返回None
"""
# 验证优先级范围
if priority < 1 or priority > 5:
print(f"[Warning] 优先级 {priority} 超出范围 (1-5),将使用默认值 3")
priority = 3
# 验证日期格式
if due_date:
try:
# 尝试解析日期格式
datetime.strptime(due_date, '%Y-%m-%d')
except ValueError:
print(f"[Warning] 日期格式 '{due_date}' 无效,应为 'YYYY-MM-DD',将设为空")
due_date = None
# 验证标题不为空
if not title:
print("[Error] 任务标题不能为空")
return None
# 创建任务对象
task = Task(
title=title,
description=description,
priority=priority,
due_date=due_date,
is_completed=is_completed,
attachment=attachment
)
# 保存任务
return self.task_repository.insert_task(task)
def create_tasks_batch(self, task_data_list: List[Dict[str, Any]]) -> Optional[int]:
"""批量创建任务
Args:
task_data_list: 包含任务数据的字典列表
Returns:
int: 成功创建的任务数量,失败则返回None
"""
if not task_data_list:
print("[Warning] 没有任务数据需要创建")
return 0
tasks = []
for task_data in task_data_list:
# 验证必要字段
if 'title' not in task_data or not task_data['title']:
print("[Warning] 跳过缺少标题的任务")
continue
# 创建任务对象
task = Task(
title=task_data['title'],
description=task_data.get('description'),
priority=task_data.get('priority', 3),
due_date=task_data.get('due_date'),
is_completed=task_data.get('is_completed', False),
attachment=task_data.get('attachment')
)
tasks.append(task)
if not tasks:
print("[Warning] 没有有效的任务需要创建")
return 0
# 批量保存任务
return self.task_repository.insert_many(tasks)
def update_task(self, task_id: int, **kwargs) -> bool:
"""更新任务
Args:
task_id: 要更新的任务ID
**kwargs: 要更新的字段和值
Returns:
bool: 更新是否成功
"""
# 查找现有任务
task = self.task_repository.find_task_by_id(task_id)
if not task:
print(f"[Error] 未找到要更新的任务 ID={task_id}")
return False
# 更新字段
for key, value in kwargs.items():
if hasattr(task, key):
setattr(task, key, value)
else:
print(f"[Warning] 忽略未知字段 '{key}'")
# 保存更新
return self.task_repository.update_task(task)
def complete_task(self, task_id: int) -> bool:
"""将任务标记为已完成
Args:
task_id: 要标记的任务ID
Returns:
bool: 操作是否成功
"""
return self.update_task(task_id, is_completed=True)
def delete_task(self, task_id: int) -> bool:
"""删除任务
Args:
task_id: 要删除的任务ID
Returns:
bool: 删除是否成功
"""
return self.task_repository.delete_task(task_id)
def get_task(self, task_id: int) -> Optional[Task]:
"""获取指定ID的任务
Args:
task_id: 任务ID
Returns:
Task: 任务对象,未找到则返回None
"""
return self.task_repository.find_task_by_id(task_id)
def get_all_tasks(self) -> List[Task]:
"""获取所有任务
Returns:
list: 任务对象列表
"""
return self.task_repository.find_all_tasks()
def get_tasks_by_priority(self, priority: int) -> List[Task]:
"""获取指定优先级的任务
Args:
priority: 优先级
Returns:
list: 任务对象列表
"""
return self.task_repository.find_by_criteria({'priority': priority})
def get_incomplete_tasks(self) -> List[Task]:
"""获取未完成的任务
Returns:
list: 任务对象列表
"""
return self.task_repository.find_by_criteria({'is_completed': 0}) # SQLite中布尔值存储为0/1
def get_overdue_tasks(self) -> List[Task]:
"""获取已逾期的任务
Returns:
list: 任务对象列表
"""
today = datetime.now().strftime('%Y-%m-%d')
# 获取所有任务
all_tasks = self.task_repository.find_all_tasks()
# 筛选出截止日期已过且未完成的任务
overdue_tasks = []
for task in all_tasks:
if (task.due_date and task.due_date < today and not task.is_completed):
overdue_tasks.append(task)
return overdue_tasks
def search_tasks_by_title(self, title_part: str) -> List[Task]:
"""搜索标题包含指定字符串的任务
Args:
title_part: 标题中要搜索的文本
Returns:
list: 符合条件的任务对象列表
"""
return self.task_repository.find_by_title_contains(title_part)
def get_tasks_due_within_days(self, days: int) -> List[Task]:
"""获取指定天数内到期的任务
Args:
days: 天数
Returns:
list: 任务对象列表
"""
today = datetime.now()
end_date = (today + timedelta(days=days)).strftime('%Y-%m-%d')
today_str = today.strftime('%Y-%m-%d')
# 获取所有任务
all_tasks = self.task_repository.find_all_tasks()
# 筛选出在指定日期范围内到期的任务
due_tasks = []
for task in all_tasks:
if (task.due_date and today_str <= task.due_date <= end_date):
due_tasks.append(task)
return due_tasks
知识点解释 (Service Layer):
Task
dataclass 实例,Service 层的方法(如 complete_task
, update_task_details
)可以直接操作这些对象的属性,代码更简洁。create_task
中增加了对 priority
和 due_date
格式的校验。complete_task
增加了对任务是否已完成的检查,避免重复操作。update_task_details
增加了对不可修改字段 (task_id
, created_at
) 的保护,并对传入的字段和值进行基本检查。get_tasks_due_within_days
中日期的解析和比较逻辑清晰地放在服务层。self.task_repository
。examples/
)现在我们已经构建了框架的各个组件(DatabaseManager
, Task
模型, TaskRepository
, TaskService
),接下来将通过几个具体的示例脚本来演示如何使用这个框架完成不同的数据库操作任务。
注意: 运行这些示例脚本时,请确保:
sqlite3
是内置的,但未来章节可能需要 pymysql
, sqlalchemy
等)。core
, models
, repositories
, services
, utils
)。这通常意味着你需要从 sqlite_practice
这个项目的根目录来运行这些示例脚本,或者确保 sqlite_practice
目录位于 Python 的模块搜索路径 (sys.path
) 中。示例代码中会包含尝试处理路径的代码。example_sqlite.db
, advanced_queries.db
, transaction_example.db
),这些文件会出现在脚本运行的目录下或指定的 data
子目录中。examples/basic_operations.py
)这个脚本演示了最核心的 CRUD (创建、读取、更新、删除) 操作,展示了如何通过 TaskRepository
来与数据库交互。
"""
基本数据库操作示例模块,展示SQLite数据库的基本CRUD操作。
包括连接、创建表、增删改查等基本功能。
"""
import os
import sqlite3
from typing import List, Optional
from core.db_manager import DatabaseManager
from core.table_manager import TableManager
from models.task import Task
from repositories.task_repository import TaskRepository
class SQLiteExample:
"""SQLite基本操作示例类,实现单例模式"""
_instance = None
def __new__(cls, *args, **kwargs):
"""确保类只有一个实例"""
if cls._instance is None:
cls._instance = super(SQLiteExample, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, db_file: str = "sqlite_practice.db"):
"""初始化SQLite示例类"""
# 避免重复初始化
if self._initialized:
return
self.db_file = db_file
self.db_manager = DatabaseManager(self.db_file)
self.conn, self.cursor = self.db_manager.connect()
self.task_repository = TaskRepository(self.conn, self.cursor)
self.init_db()
self._initialized = True
def init_db(self) -> bool:
"""初始化数据库和表结构
Returns:
bool: 初始化是否成功
"""
try:
# 若数据库文件存在,则删除原有数据库
if os.path.exists(self.db_file):
self.clean_db()
# 创建任务表
return self.task_repository.create_table()
except Exception as e:
print(f"[Error] 初始化数据库失败: {e}")
return False
def clean_db(self) -> bool:
"""清理数据库文件
Returns:
bool: 清理是否成功
"""
return self.db_manager.clean_database()
def insert_sample_tasks(self) -> int:
"""插入示例任务数据
Returns:
int: 成功插入的任务数量
"""
print("\n--- 插入示例任务数据 ---")
tasks = [
Task(title="任务1", description="这是任务1的描述", priority=1),
Task(title="任务2", description="这是任务2的描述", priority=2),
Task(title="任务3", description="这是任务3的描述", priority=3),
Task(title="任务4", description="这是任务4的描述", priority=4),
Task(title="任务5", description="这是任务5的描述", priority=5),
]
return self.task_repository.insert_many(tasks)
def insert_single_task(self, task: Task) -> Optional[int]:
"""插入单个任务
Args:
task: 要插入的任务对象
Returns:
Optional[int]: 插入的任务ID,失败则返回None
"""
return self.task_repository.insert_task(task)
def find_all_tasks(self) -> List[Task]:
"""查询所有任务
Returns:
List[Task]: 任务列表
"""
return self.task_repository.find_all_tasks()
def find_task_by_id(self, task_id: int) -> Optional[Task]:
"""根据ID查询任务
Args:
task_id: 任务ID
Returns:
Optional[Task]: 任务对象,未找到则返回None
"""
return self.task_repository.find_task_by_id(task_id)
def find_tasks_by_criteria(self, criteria: dict) -> List[Task]:
"""根据条件查询任务
Args:
criteria: 查询条件字典
Returns:
List[Task]: 符合条件的任务列表
"""
return self.task_repository.find_by_criteria(criteria)
def find_tasks_by_title(self, title_part: str) -> List[Task]:
"""根据标题模糊查询任务
Args:
title_part: 标题包含的文本
Returns:
List[Task]: 符合条件的任务列表
"""
return self.task_repository.find_by_title_contains(title_part)
def update_task(self, task: Task) -> bool:
"""更新任务
Args:
task: 要更新的任务对象
Returns:
bool: 更新是否成功
"""
return self.task_repository.update_task(task)
def delete_task(self, task_id: int) -> bool:
"""删除任务
Args:
task_id: 要删除的任务ID
Returns:
bool: 删除是否成功
"""
return self.task_repository.delete_task(task_id)
def close(self) -> None:
"""关闭数据库连接"""
self.db_manager.close()
if __name__ == '__main__':
# 实例化示例类
print('\033[1;36m===== 初始化 SQLite 示例 =====\033[0m')
sqlite_example = SQLiteExample()
# 插入示例任务数据
print('\033[1;32m===== 插入示例任务数据 =====\033[0m')
sqlite_example.insert_sample_tasks()
# 查询所有任务
print('\033[1;34m===== 查询所有任务 =====\033[0m')
tasks = sqlite_example.find_all_tasks()
for task in tasks:
print('\033[0;37m' + str(task.to_dict()) + '\033[0m')
# 查询单个任务
print('\033[1;34m===== 查询单个任务 (ID=1) =====\033[0m')
task = sqlite_example.find_task_by_id(1)
print('\033[0;33m' + str(task.to_dict()) + '\033[0m')
# 查询任务列表
print('\033[1;34m===== 查询高优先级任务 (priority=3) =====\033[0m')
tasks = sqlite_example.find_tasks_by_criteria({'priority': 3})
for task in tasks:
print('\033[0;35m' + str(task.to_dict()) + '\033[0m')
# 查询任务标题包含某些字符的任务
print('\033[1;34m===== 查询标题包含"任务"的任务 =====\033[0m')
tasks = sqlite_example.find_tasks_by_title('任务')
for task in tasks:
print('\033[0;36m' + str(task.to_dict()) + '\033[0m')
# 更新任务
print('\033[1;33m===== 更新任务 (ID=1) =====\033[0m')
task = Task(task_id=1, title="任务1-更新", description="这是任务1的更新描述", priority=1, is_completed=True)
sqlite_example.update_task(task)
# 查询更新后的任务
print('\033[1;34m===== 查询更新后的任务 (ID=1) =====\033[0m')
task = sqlite_example.find_task_by_id(1)
print('\033[0;32m' + str(task.to_dict()) + '\033[0m')
# 删除任务
print('\033[1;31m===== 删除任务 (ID=1) =====\033[0m')
sqlite_example.delete_task(1)
# 查询删除后的任务
print('\033[1;34m===== 查询删除后的任务 (ID=1) =====\033[0m')
task = sqlite_example.find_task_by_id(1)
print('\033[0;31m' + str(task) + '\033[0m')
# 关闭数据库连接
print('\033[1;36m===== 关闭数据库连接 =====\033[0m')
sqlite_example.close()
代码解释 (basic_operations.py):
TaskRepository
实例 (task_repo
) 的方法 (insert
, insert_many
, find_all
, find_by_id
, update
, find_by_criteria
, find_by_title_contains
, delete
) 来完成数据库操作。Task
dataclass 对象用于插入和更新,接收 Task
对象列表作为查询结果。db_manager.connect()
和 db_manager.close()
来展示 Repository 如何接收 conn
和 cursor
。在实际应用或更复杂的示例(如下面的事务示例)中,使用 with
语句通常是更好的选择。try...except...finally
来捕获可能的数据库错误或其他异常,并确保连接最终被关闭。examples/advanced_queries.py
)此脚本展示了更复杂的查询场景,包括直接执行 SQL 和通过框架的 Service/Repository 层进行查询。
# sqlite_practice/examples/advanced_queries_oop.py
"""
高级查询示例模块 (面向对象重构版)。
使用类封装演示流程,遵循面向对象原则,使用彩色输出。
展示SQLite数据库的高级查询功能,通过调用分层框架实现。
"""
import sys
import os
import sqlite3
from datetime import datetime, timedelta # 导入 datetime/timedelta
from typing import Optional, Any, List # 引入类型提示
# --- 导入框架组件 ---
from core.db_manager import DatabaseManager
from models.task import Task # Task dataclass
from repositories.task_repository import TaskRepository
from services.task_service import TaskService
# --- 辅助类:用于彩色/样式化打印 ---
class Colors:
"""定义 ANSI 转义码常量用于彩色输出"""
HEADER = '\033[95m' # 紫色 (用于标题)
BLUE = '\033[94m' # 蓝色 (用于 SQL 或代码)
CYAN = '\033[96m' # 青色 (用于提示信息)
GREEN = '\033[92m' # 绿色 (用于成功信息)
WARNING = '\033[93m' # 黄色 (用于警告)
FAIL = '\033[91m' # 红色 (用于错误)
BOLD = '\033[1m' # 加粗
UNDERLINE = '\033[4m' # 下划线
END = '\033[0m' # 重置所有格式
# --- 打印辅助函数 ---
def print_header(text: str):
"""打印带样式的标题"""
print(f"\n{Colors.HEADER}{Colors.BOLD}--- {text} ---{Colors.END}")
def print_subheader(text: str):
"""打印带样式的子标题"""
print(f"\n{Colors.CYAN}{Colors.UNDERLINE} {text}{Colors.END}")
def print_info(text: str):
"""打印普通信息"""
print(f" {text}")
def print_success(text: str):
"""打印成功信息"""
print(f"{Colors.GREEN} ✔ {text}{Colors.END}")
def print_warning(text: str):
"""打印警告信息"""
print(f"{Colors.WARNING} ⚠️ [Warning] {text}{Colors.END}")
def print_error(text: str):
"""打印错误信息"""
print(f"{Colors.FAIL} ❌ [Error] {text}{Colors.END}")
def print_sql(sql: str):
"""打印格式化的 SQL 语句"""
print(f"{Colors.BLUE} SQL: {sql.strip()}{Colors.END}")
def print_result_item(item: Any, indent: int = 4):
"""打印格式化的查询结果项"""
prefix = " " * indent
if isinstance(item, Task):
# 使用 Task 的 __repr__ (dataclass 自动生成)
print(f"{prefix}{item}")
elif hasattr(item, 'keys'): # 支持 sqlite3.Row 或字典
details = ", ".join([f"{Colors.BOLD}{key}{Colors.END}: {item[key]}" for key in item.keys()])
print(f"{prefix}Row({details})")
else:
print(f"{prefix}{item}")
# --- 主演示类 ---
class AdvancedQueryDemoOO:
"""
面向对象的 SQLite 高级查询演示类。
封装了演示的设置、执行和清理逻辑。
"""
# 类属性定义数据库文件路径
DB_FILE = "advanced_queries_oop.db"
def __init__(self):
"""初始化演示类实例"""
print_header("初始化高级查询演示 (OOP)")
# 创建 DatabaseManager 实例,但不立即连接
self.db_manager = DatabaseManager(self.DB_FILE)
# 初始化依赖对象为 None
self.conn: Optional[sqlite3.Connection] = None
self.cursor: Optional[sqlite3.Cursor] = None
self.task_repo: Optional[TaskRepository] = None
self.task_service: Optional[TaskService] = None
# 执行数据库清理
self._cleanup_db()
def _cleanup_db(self):
"""(私有方法) 清理旧的数据库文件"""
if os.path.exists(self.DB_FILE):
try:
os.remove(self.DB_FILE)
print_info(f"旧数据库文件 '{self.DB_FILE}' 已清理。")
except OSError as e:
print_error(f"清理数据库文件 '{self.DB_FILE}' 失败: {e}")
# 清理失败可能导致后续问题,可以选择退出或继续
# sys.exit(1)
def _setup_database_and_framework(self):
"""(私有方法) 建立数据库连接,初始化框架组件,并填充测试数据"""
print_subheader("1. 设置数据库和框架组件")
# self.db_manager 在 __init__ 中已创建
# 使用 with 语句管理连接
# 注意:这里我们将 conn 和 cursor 存储在 self 上,以便后续方法使用
# 这在简单演示类中可行,但在多线程或长期运行应用中需谨慎管理状态
self.conn, self.cursor = self.db_manager.connect() # 直接调用 connect
if not self.conn or not self.cursor:
raise ConnectionError("未能获取数据库连接或游标")
# 初始化 Repository (依赖 conn, cursor)
# Repository 初始化时会尝试创建表
self.task_repo = TaskRepository(self.conn, self.cursor)
# 初始化 Service (依赖 task_repo)
self.task_service = TaskService(self.task_repo)
print_success("Repository 和 Service 初始化成功。")
# 新建数据表
self.task_repo.create_table()
# 填充测试数据
if not self._populate_test_data():
raise RuntimeError("填充测试数据失败。") # 如果填充失败,抛出异常
def get_future_date(self, days: int, format_str: str = '%Y-%m-%d') -> str:
"""获取未来日期字符串
Args:
days: 向后推的天数
format_str: 日期格式,默认为 'YYYY-MM-DD'
Returns:
str: 未来日期字符串
"""
future_date = datetime.now() + timedelta(days=days)
return future_date.strftime(format_str)
def _populate_test_data(self) -> bool:
"""(私有方法) 向数据库填充测试数据"""
if not self.task_repo: return False # 防御性检查
print_subheader("2. 填充测试数据")
today = datetime.now().strftime('%Y-%m-%d')
tomorrow = self.get_future_date(1)
next_week = self.get_future_date(7)
next_month = self.get_future_date(30)
yesterday = self.get_future_date(-1)
test_tasks = [
Task(title="完成项目设计文档", description="项目需求和架构设计", priority=1, due_date=tomorrow),
Task(title="修复关键Bug", description="修复用户报告的登录问题", priority=1, due_date=today),
Task(title="制定项目计划", description="项目时间线和里程碑规划", priority=1, due_date=next_week),
Task(title="紧急会议准备", priority=1, due_date=today, is_completed=False),
Task(title="代码重构", description="重构认证模块的代码", priority=2, due_date=next_week),
Task(title="单元测试编写", description="为核心功能编写单元测试", priority=2, due_date=next_week),
Task(title="API文档更新", description="更新REST API文档", priority=2, due_date=next_month),
Task(title="中期报告撰写", priority=2, due_date=self.get_future_date(10)),
Task(title="性能优化", description="优化数据库查询性能", priority=3, due_date=next_month),
Task(title="学习新技术", description="学习 GraphQL 技术", priority=3, due_date=None),
Task(title="代码审查", description="审查团队成员的代码", priority=3, due_date=next_week),
Task(title="环境搭建", description="搭建开发环境", priority=1, is_completed=True, due_date=yesterday),
Task(title="需求分析", description="分析用户需求", priority=2, is_completed=True, due_date=today),
Task(title="旧项目归档", priority=3, is_completed=True, due_date=None)
]
inserted_count = self.task_repo.insert_many(test_tasks)
if inserted_count is not None:
print_success(
f"已尝试插入 {len(test_tasks)} 条测试数据,成功 {inserted_count if inserted_count != -1 else len(test_tasks)} 条。")
return True
else:
print_error("填充测试数据过程中发生错误。")
return False
def _demonstrate_raw_sql(self):
"""(私有方法) 演示直接执行原生 SQL 查询"""
print_header("演示: 直接执行原生 SQL 查询")
if not self.cursor:
print_error("数据库游标无效。")
return
try:
print_subheader("1. 按优先级统计任务数量 (GROUP BY, COUNT)")
self.cursor.execute(
f"SELECT priority, COUNT(*) as task_count FROM {self.task_repo.table_name} GROUP BY priority ORDER BY priority")
for row in self.cursor.fetchall(): print_result_item(row)
print_subheader("2. 连接查询示例 (LEFT JOIN 语法演示)")
join_sql = f"SELECT t.task_id, t.title, u.username FROM {self.task_repo.table_name} t LEFT JOIN users u ON t.user_id = u.user_id WHERE t.priority = 1 LIMIT 5"
print_sql(join_sql + " (仅演示语法, users 表不存在)")
print_subheader("3. 使用 CASE 表达式统计任务状态")
self.cursor.execute(
f"SELECT SUM(CASE WHEN is_completed = 1 THEN 1 ELSE 0 END) as completed_count, SUM(CASE WHEN is_completed = 0 THEN 1 ELSE 0 END) as pending_count, COUNT(*) as total_count FROM {self.task_repo.table_name}")
stats = self.cursor.fetchone()
if stats: print_result_item(stats)
print_subheader("4. 复杂条件查询: 一周内到期的高优先级(<=2)未完成任务")
self.cursor.execute(
f"SELECT task_id, title, priority, due_date FROM {self.task_repo.table_name} WHERE is_completed = 0 AND priority <= ? AND (due_date IS NOT NULL AND due_date <= date('now', '+7 days')) ORDER BY priority ASC, due_date ASC",
(2,))
urgent_tasks = self.cursor.fetchall()
print_info(f"找到 {len(urgent_tasks)} 条:")
for task in urgent_tasks: print_result_item(task)
print_subheader("5. 分页查询: 获取第 2 页数据 (每页 3 条)")
page_size, page_number = 3, 2
offset = (page_number - 1) * page_size
self.cursor.execute(
f"SELECT task_id, title FROM {self.task_repo.table_name} ORDER BY task_id LIMIT ? OFFSET ?",
(page_size, offset))
page_tasks = self.cursor.fetchall()
print_info(f"第 {page_number} 页 (每页 {page_size} 条):")
for task in page_tasks: print_result_item(task)
print_subheader("6. 子查询: 查找所有今天到期的任务 (使用 EXISTS)")
self.cursor.execute(
f"SELECT task_id, title, due_date FROM {self.task_repo.table_name} t1 WHERE EXISTS (SELECT 1 FROM {self.task_repo.table_name} t2 WHERE t2.task_id = t1.task_id AND t2.due_date = date('now')) ORDER BY priority")
today_tasks = self.cursor.fetchall()
print_info(f"今天 ({datetime.now().strftime('%Y-%m-%d')}) 到期的任务 ({len(today_tasks)} 条):")
for task in today_tasks: print_result_item(task)
except sqlite3.Error as e:
print_error(f"执行原生 SQL 查询时出错: {e}")
def _demonstrate_repository_queries(self):
"""(私有方法) 演示使用 TaskRepository 方法进行查询"""
print_header("演示: 使用 Repository 方法进行查询")
if not self.task_repo:
print_error("TaskRepository 未初始化。")
return
try:
print_subheader("1. 查找优先级为 1 的任务")
priority_1_tasks = self.task_repo.find_by_criteria({"priority": 1})
print_info(f"找到 {len(priority_1_tasks)} 条:")
for task in priority_1_tasks: print_result_item(task)
print_subheader("2. 查找已完成的任务")
completed_tasks = self.task_repo.find_by_criteria({"is_completed": True})
print_info(f"找到 {len(completed_tasks)} 条:")
for task in completed_tasks: print_result_item(task)
print_subheader("3. 搜索标题中包含 '代码' 的任务")
code_tasks = self.task_repo.find_by_title_contains("代码")
print_info(f"找到 {len(code_tasks)} 条:")
for task in code_tasks: print_result_item(task)
print_subheader("4. 组合条件查询 (优先级=2 且 未完成)")
prio2_incomplete = self.task_repo.find_by_criteria({"priority": 2, "is_completed": False})
print_info(f"找到 {len(prio2_incomplete)} 条:")
for task in prio2_incomplete: print_result_item(task)
except Exception as e:
print_error(f"使用 Repository 查询时出错: {e}")
def _demonstrate_service_queries(self):
"""(私有方法) 演示使用 TaskService 方法进行业务查询"""
print_header("演示: 使用 Service 层进行业务查询")
if not self.task_service:
print_error("TaskService 未初始化。")
return
try:
print_subheader("1. 获取未完成的任务")
incomplete_tasks = self.task_service.get_incomplete_tasks()
print_info(f"找到 {len(incomplete_tasks)} 条:")
for task in incomplete_tasks[:3]: print_result_item(task)
if len(incomplete_tasks) > 3: print_info(f" ... 及其他 {len(incomplete_tasks) - 3} 条")
print_subheader("2. 获取未来 7 天内到期的任务")
due_soon_tasks = self.task_service.get_tasks_due_within_days(7)
print_info(f"找到 {len(due_soon_tasks)} 条:")
for task in due_soon_tasks: print_result_item(task)
print_subheader("3. 获取已逾期的未完成任务")
overdue_tasks = self.task_service.get_overdue_tasks()
print_info(f"找到 {len(overdue_tasks)} 条:")
for task in overdue_tasks: print_result_item(task)
except Exception as e:
print_error(f"使用 Service 查询时出错: {e}")
def run(self):
"""
执行完整的演示流程: 设置 -> 填充数据 -> 各种查询演示 -> 清理。
使用 try...finally 确保数据库连接总是被关闭。
"""
try:
# --- 步骤 1: 设置数据库和框架 ---
# 使用 with 语句确保连接在 setup 完成后也可用,
# 同时在 run 方法结束时通过 finally 关闭。
# 或者,让 setup 返回 conn/cursor/repo/service
self._setup_database_and_framework()
# --- 步骤 2: 执行各种查询演示 ---
if self.cursor: self._demonstrate_raw_sql()
if self.task_repo: self._demonstrate_repository_queries()
if self.task_service: self._demonstrate_service_queries()
print_success("\n所有查询演示执行完毕。")
except ConnectionError as ce:
print_error(f"数据库连接错误: {ce}")
except RuntimeError as rte: # 捕获 setup 中可能抛出的错误
print_error(f"运行时错误 (可能在设置或填充数据时): {rte}")
except sqlite3.Error as db_err:
print_error(f"数据库操作错误: {db_err}")
import traceback
traceback.print_exc()
except Exception as e:
print_error(f"发生意外错误: {e}")
import traceback
traceback.print_exc()
finally:
# --- 步骤 3: 清理 ---
# 无论 run 方法中发生什么,都确保关闭连接
if self.db_manager:
self.db_manager.close()
print_header("高级查询演示流程结束")
# --- 脚本入口 ---
if __name__ == "__main__":
# 1. 创建演示类的实例
demo = AdvancedQueryDemoOO()
# 2. 运行演示
demo.run()
代码解释 (advanced_queries.py):
populate_test_data
: 创建了多样化的任务数据,包括不同的优先级、截止日期和完成状态,为后续复杂查询提供基础。run_raw_sql_queries
: 直接使用 cursor
执行 SQL。这展示了:
COUNT
, SUM
) 和 GROUP BY
: 用于统计分析。CASE
表达式: 在 SQL 中实现条件逻辑。WHERE
子句: 结合 AND
, OR
, 比较运算符, IS NOT NULL
和 SQLite 内建日期函数 (date('now', '+7 days')
)。ORDER BY
多列排序。LIMIT
和 OFFSET
: 实现分页。EXISTS
子查询: 进行基于关联的条件判断。run_repository_queries
: 使用 TaskRepository
提供的 find_by_criteria
和 find_by_title_contains
方法。展示了通过抽象接口进行条件查询。run_service_queries
: 使用 TaskService
提供的面向业务的方法,如 get_incomplete_tasks
, get_tasks_due_within_days
, get_overdue_tasks
。这体现了 Service 层封装业务逻辑查询的优势。run_advanced_queries_main
: 主函数负责初始化框架组件(使用 with DatabaseManager
),调用数据填充函数,然后依次调用三种查询方式的演示函数。PyMySQL
驱动)16.2.1 简介
PyMySQL
前,你需要确保:
pip install PyMySQL
sqlite3
主要不同:
%s
。utf8mb4
) 和游标类型 (DictCursor
) 的设置。16.2.2 建立连接
连接是与 MySQL 服务器交互的第一步。
我们把上一章的打印工具给拿过来,放置在根目录print_utils
"""
打印工具模块,提供彩色和结构化的打印函数。
"""
# ======== 彩色打印工具 ========
class Colors:
HEADER = '\033[95m'
BLUE = '\033[94m'
CYAN = '\033[96m'
GREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
END = '\033[0m'
def print_header(text: str):
print(f"\n{Colors.HEADER}{Colors.BOLD}--- {text} ---{Colors.END}")
def print_subheader(text: str):
print(f"\n{Colors.CYAN}{Colors.UNDERLINE} {text}{Colors.END}")
def print_info(text: str):
print(f" {text}")
def print_success(text: str):
print(f"{Colors.GREEN} ✔ {text}{Colors.END}")
def print_warning(text: str):
print(f"{Colors.WARNING} ⚠️ [Warning] {text}{Colors.END}")
def print_error(text: str):
print(f"{Colors.FAIL} ❌ [Error] {text}{Colors.END}")
def print_sql(sql: str):
print(f"{Colors.BLUE} SQL: {sql.strip()}{Colors.END}")
def print_result_item(item, indent: int = 4):
prefix = " " * indent
if isinstance(item, dict):
details = ", ".join([
f"{Colors.BOLD}{key}{Colors.END}: {repr(value)}" for key, value in item.items()
])
print(f"{prefix}Row({details})")
else:
print(f"{prefix}{repr(item)}")
# ======== END 彩色打印工具 ========
当我们需要使用时,直接from导入
并像调用函数一样调用即可
# --- PyMySQL 连接示例 ---
import pymysql
from pymysql.cursors import DictCursor # 导入字典游标类
from typing import Optional, Tuple, Dict, List, Any # 类型提示
from sqlite_practice.utils.print_utils import (
Colors, print_header, print_subheader, print_info, print_success, print_warning, print_error, print_sql, print_result_item
)
# --- 数据库连接配置 ---
# !!! 警告: 生产环境中绝不应硬编码密码! 应使用更安全的方式管理凭证 !!!
DB_CONFIG: Dict[str, Any] = {
'host': 'localhost', # MySQL 服务器主机名或 IP 地址
'port': 3306, # MySQL 默认端口
'user': 'root', # 替换为你的 MySQL 用户名
'password': 'root', # 替换为你的 MySQL 密码
'database': 'pymysql_demo_db', # 替换为你要连接的数据库名 (需预先创建)
'charset': 'utf8mb4', # !!! 强烈推荐: 支持完整 Unicode, 包括 emoji !!!
'cursorclass': DictCursor, # !!! 强烈推荐: 让查询结果以字典形式返回 !!!
'autocommit': False, # !!! 推荐: 关闭自动提交,显式控制事务 !!!
'connect_timeout': 10 # (可选) 连接超时时间 (秒)
}
# --- 获取数据库连接的函数 ---
def get_mysql_connection() -> Optional[pymysql.connections.Connection]:
"""尝试建立到 MySQL 数据库的连接。"""
print_info(f"尝试连接到 MySQL: {DB_CONFIG['host']}:{DB_CONFIG['port']}, Database: {DB_CONFIG['database']}")
try:
connection = pymysql.connect(**DB_CONFIG)
print_success(f"成功连接到数据库 '{DB_CONFIG['database']}'")
return connection
except pymysql.Error as e:
print_error(f"数据库连接失败: {e}")
print_warning(" 请检查 DB_CONFIG 配置、MySQL 服务器状态及网络连接。")
return None
连接参数关键点:
charset='utf8mb4'
: 确保可以正确处理各种字符,包括表情符号。数据库、表、列也应使用此字符集。cursorclass=DictCursor
: 使 cursor.fetchone()
返回字典,cursor.fetchall()
返回字典列表。可通过 row['column_name']
访问数据,比元组索引更易读。autocommit=False
: 禁用自动提交。所有 INSERT
, UPDATE
, DELETE
操作都需要显式调用 conn.commit()
才生效,便于事务管理。16.2.3 核心 API 概览
PyMySQL
遵循 DB-API 2.0 规范,核心 API 与 sqlite3
类似,但注意占位符和游标类型。
API | 描述 | 注意事项 |
---|---|---|
pymysql.connect(**config) |
连接 MySQL,返回 Connection 对象。 |
参数多,需网络 |
connection.cursor([cursorclass]) |
创建 Cursor 对象 (推荐 DictCursor )。 |
DictCursor 返回字典 |
cursor.execute(sql, [args]) |
执行单条 SQL (用 %s 占位符)。返回受影响行数。 |
%s 占位符 |
cursor.executemany(sql, seq_of_args) |
批量执行 SQL (用 %s )。返回受影响总行数 (可能不精确)。 |
%s 占位符 |
connection.commit() |
提交当前事务。 | autocommit=False 时必需 |
connection.rollback() |
回滚当前事务。 | |
connection.begin() |
显式开始事务 (若 autocommit=False )。 |
可选,但更清晰 |
cursor.fetchone() |
获取下一行结果 (字典或元组),无结果时 None 。 |
依赖 cursorclass |
cursor.fetchall() |
获取所有剩余行结果 (字典或元组的列表)。 | 依赖 cursorclass |
cursor.lastrowid |
(属性) 最后 INSERT 操作的自增 ID (AUTO_INCREMENT 列)。 |
|
cursor.rowcount |
(属性) 最后操作影响的行数 (对 SELECT 可能不同数据库行为有差异)。 | |
connection.close() |
必须调用关闭网络连接。 | 极其重要 |
with connection.cursor() as cursor: |
(推荐) 游标上下文管理器,自动关闭游标。 | |
connection.ping(reconnect=True) |
检查连接活性,可选重连。 | 用于长连接场景 |
16.2.4 准备工作:创建表与填充数据
为了演示查询,我们创建 categories
和 products
表。
categories 表
列名 | 数据类型 | 是否允许 NULL | 键 | 默认值 | 额外 | 注释 |
---|---|---|---|---|---|---|
category_id | INT UNSIGNED | NO | PRI | — | AUTO_INCREMENT | 产品类别主键 |
name | VARCHAR(50) | NO | UNI | — | 类别名称 | |
description | TEXT | YES | NULL | 类别描述 |
products 表
列名 | 数据类型 | 是否允许 NULL | 键 | 默认值 | 额外 | 注释 |
---|---|---|---|---|---|---|
id | INT UNSIGNED | NO | PRI | — | AUTO_INCREMENT | 产品ID |
name | VARCHAR(100) | NO | — | 产品名称 | ||
price | DECIMAL(10, 2) | NO | — | 价格 | ||
stock | INT UNSIGNED | NO | 0 | 库存 | ||
category_id | INT UNSIGNED | YES | MUL | NULL | 外键,关联 categories.category_id | |
added_date | DATE | YES | NULL | 上架日期 | ||
created_at | TIMESTAMP | YES | CURRENT_TIMESTAMP | 创建时间 | ||
updated_at | TIMESTAMP | YES | CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP | 更新时间 |
这两个函数仅是一次性的,用于创建表与插入数据,可以看看语法,也可以不看,简单的SQL而已
# --- PyMySQL 连接示例 ---
import pymysql
from pymysql.cursors import DictCursor # 导入字典游标类
from typing import Optional, Tuple, Dict, List, Any # 类型提示
from print_utils import (
Colors, print_header, print_subheader, print_info, print_success, print_warning, print_error, print_sql, print_result_item
)
# --- 数据库连接配置 ---
# !!! 警告: 生产环境中绝不应硬编码密码! 应使用更安全的方式管理凭证 !!!
DB_CONFIG: Dict[str, Any] = {
'host': 'localhost', # MySQL 服务器主机名或 IP 地址
'port': 3306, # MySQL 默认端口
'user': 'root', # 替换为你的 MySQL 用户名
'password': 'root', # 替换为你的 MySQL 密码
'database': 'pymysql_demo_db', # 替换为你要连接的数据库名 (需预先创建)
'charset': 'utf8mb4', # !!! 强烈推荐: 支持完整 Unicode, 包括 emoji !!!
'cursorclass': DictCursor, # !!! 强烈推荐: 让查询结果以字典形式返回 !!!
'autocommit': False, # !!! 推荐: 关闭自动提交,显式控制事务 !!!
'connect_timeout': 10 # (可选) 连接超时时间 (秒)
}
def get_mysql_connection() -> Optional[pymysql.Connection]:
"""创建 MySQL 连接"""
print_info(f"尝试连接到 MySQL: {DB_CONFIG['host']}:{DB_CONFIG['port']}, Database: {DB_CONFIG['database']}")
try:
connection = pymysql.connect(**DB_CONFIG)
print_success("连接成功")
return connection
except pymysql.Error as e:
print_error(f"数据库连接失败: {e}")
print_warning(" 请检查 DB_CONFIG 配置、MySQL 服务器状态及网络连接。")
return None
# --- 表结构定义与创建 ---
def setup_mysql_tables(conn: pymysql.connections.Connection) -> bool:
"""创建演示所需的 categories 和 products 表。"""
try:
with conn.cursor() as cursor:
print_info("开始创建表...")
cursor.execute("DROP TABLE IF EXISTS products;") # 先删除依赖表
cursor.execute("DROP TABLE IF EXISTS categories;") # 再删除主表
print_warning(" (旧表 'products' 和 'categories' 已删除,如果存在)")
# --- 创建 categories 表 ---
cursor.execute("""
CREATE TABLE categories (
category_id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(50) NOT NULL UNIQUE COMMENT '类别名称',
description TEXT COMMENT '类别描述'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='产品类别表';
""")
print_success("表 'categories' 创建成功。")
# --- 创建 products 表 (包含外键) ---
cursor.execute("""
CREATE TABLE products (
id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY COMMENT '产品ID',
name VARCHAR(100) NOT NULL COMMENT '产品名称',
price DECIMAL(10, 2) NOT NULL COMMENT '价格',
stock INT UNSIGNED NOT NULL DEFAULT 0 COMMENT '库存',
category_id INT UNSIGNED COMMENT '外键, 关联 categories 表',
added_date DATE COMMENT '上架日期',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
INDEX idx_prod_name (name(20)), -- 名称前缀索引
FOREIGN KEY fk_prod_cat (category_id) REFERENCES categories(category_id)
ON DELETE SET NULL -- 删除类别时,产品类别设为 NULL
ON UPDATE CASCADE -- 更新类别 ID 时,产品自动更新
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='产品信息表';
""")
print_success("表 'products' 创建成功。")
conn.commit() # 提交 DDL 更改
return True
except pymysql.Error as e:
print_error(f"创建表失败: {e}")
conn.rollback()
return False
# --- 数据填充 ---
def populate_mysql_data(conn: pymysql.connections.Connection) -> bool:
"""向 categories 和 products 表填充测试数据。"""
try:
with conn.cursor() as cursor:
print_info("开始填充测试数据...")
# --- 插入类别 ---
categories_data = [
('电子产品', '手机、电脑、配件等'), ('图书音像', '各类实体书籍和数字媒体'),
('服装鞋包', '男女服装、鞋子和箱包'), ('家居生活', '家具、厨具、家纺等'),
]
cat_sql = "INSERT INTO categories (name, description) VALUES (%s, %s)"
cursor.executemany(cat_sql, categories_data)
print_success(f" 插入了 {cursor.rowcount} 个类别。")
# --- 获取类别 ID 用于关联 ---
# 注意:更可靠的方式是按名称查询获取 ID,这里简化
cursor.execute("SELECT category_id, name FROM categories")
cat_map = {row['name']: row['category_id'] for row in cursor.fetchall()}
# --- 插入产品 ---
products_data = [
('智能手机 V12', 4999.00, 50, cat_map['电子产品'], '2024-10-01'),
('蓝牙耳机 AirSound', 799.00, 100, cat_map['电子产品'], '2024-11-15'),
('PyMySQL 深度指南', 88.50, 200, cat_map['图书音像'], '2023-05-20'),
('算法导论 (原版)', 128.00, 80, cat_map['图书音像'], '2024-01-10'),
('纯棉印花 T 恤', 129.00, 300, cat_map['服装鞋包'], '2024-08-01'),
('透气运动跑鞋', 499.00, 60, cat_map['服装鞋包'], '2024-09-01'),
('北欧风实木餐桌', 2899.00, 10, cat_map['家居生活'], '2023-12-01'),
('乳胶记忆棉枕头', 159.00, 150, cat_map['家居生活'], '2024-03-01'),
('游戏本 RTX 9000', 12999.00, 20, cat_map['电子产品'], None),
('科幻短篇小说集', 65.00, 120, cat_map['图书音像'], '2024-06-15'),
('速干运动短裤', 189.00, 100, None, '2024-07-01'), # 无类别
('咖啡机', 899.00, 30, cat_map['家居生活'], '2024-05-01'),
('SQL 查询的艺术', 75.00, 90, cat_map['图书音像'], '2024-04-15')
]
prod_sql = "INSERT INTO products (name, price, stock, category_id, added_date) VALUES (%s, %s, %s, %s, %s)"
cursor.executemany(prod_sql, products_data)
print_success(f" 插入了 {cursor.rowcount} 个产品。")
conn.commit() # 提交所有插入
print_success("测试数据填充完成。")
return True
except pymysql.Error as e:
print_error(f"填充数据失败: {e}")
conn.rollback()
return False
if __name__ == '__main__':
# 1.连接数据库
conn = get_mysql_connection()
if conn is None:
exit(1)
# 2.创建表结构
if not setup_mysql_tables(conn):
exit(1)
# 3.填充测试数据
if not populate_mysql_data(conn):
exit(1)
# 4.查询数据
with conn.cursor() as cursor:
print_header("查询数据")
def insert_single_product(conn: pymysql.Connection, product_name: str, price: float, stock: int,
category_id: Optional[int] = None, added_date: Optional[str] = None) -> Optional[int]:
"""
插入单个产品记录
:param conn: 数据库连接
:param product_name: 产品名称
:param price: 产品价格
:param stock: 产品库存
:param category_id: 产品分类 ID (可选)
:param added_date: 产品添加日期 (可选)
:return: 新插入的产品 ID (若插入失败则返回 None)
"""
sql = """
INSERT INTO products (name, price, stock, category_id, added_date)
VALUES (%s, %s, %s, %s, %s)
"""
params = (product_name, price, stock, category_id, added_date)
try:
with conn.cursor() as cursor:
cursor.execute(sql, params)
new_id = cursor.lastrowid # 获取自增主键值
conn.commit()
print_success(f"成功插入产品:{product_name} (ID: {new_id})")
return new_id
except pymysql.Error as e:
print_error(f"插入产品失败: {e}")
conn.rollback()
return None
if __name__ == '__main__':
# 1.连接数据库
conn = get_mysql_connection()
if conn is None:
exit(1)
try:
new_id = insert_single_product(conn, product_name='iPhone X', price=9999.99, stock=100, category_id=1,
added_date='2021-01-01')
if new_id is not None:
print_success(f"新插入的产品 ID: {new_id}")
except Exception as e:
print_error(f"插入产品失败: {e}")
executemany
)def insert_multiple_products(conn: pymysql.Connection, products_list: List[Tuple]) -> Optional[List[int]]:
"""
批量插入产品记录
:param conn: 数据库连接
:param products_list: 产品列表 (列表项为元组,包含 (name, price, stock, category_id, added_date) 五个字段)
:return: 新插入的产品 ID 列表 (若插入失败则返回 None)
"""
if not products_list: return 0
sql = "INSERT INTO products (name,price,stock,category_id,added_date,created_at) VALUES (%s,%s,%s,%s,%s,NOW())"
try:
with conn.cursor() as cursor:
cursor.executemany(sql, products_list)
new_ids = [cursor.lastrowid for _ in range(len(products_list))] # 获取自增主键值
conn.commit()
print_success(f"成功插入 {len(products_list)} 个产品")
return new_ids
except pymysql.Error as e:
print_error(f"插入产品失败: {e}")
conn.rollback()
return None
# --- 使用 ---
if __name__ == '__main__':
# 1.创建数据库连接
conn = get_mysql_connection()
if not conn:
exit(1)
# 演示插入多个产品
products_list = [
('iPhone X', 9999, 100, 1, '2021-01-01'),
('华为 P30 Pro', 8888, 50, 2, '2021-01-02'),
('小米 10', 7777, 20, 3, '2021-01-03'),
('OPPO Find X3', 6666, 10, 4, '2021-01-04'),
('vivo NEX', 5555, 5, 4, '2021-01-05'),
]
new_ids = insert_multiple_products(conn, products_list)
if new_ids:
print_info(f"新插入的产品 ID 列表: {new_ids}")
注意,在使用的时候category_id 、 时不能存储4以上的,由于我们上面插入的insert外键最高为4,若需要更新为4类别的,则需要增加category_id的数量
def update_product_stock(conn: pymysql.Connection, product_id: int, stock: int) -> bool:
"""更新指定产品的库存 (增加或减少)。"""
# 注意:在真实应用中,直接加减可能导致并发问题,应使用更安全的更新方式
# 例如:UPDATE products SET stock = stock + %s WHERE id = %s AND stock + %s >= 0
# 这里仅作基本 UPDATE 演示
sql_get_stock = "SELECT stock FROM products WHERE id = %s"
sql_update = "UPDATE products SET stock = %s WHERE id = %s"
try:
with conn.cursor() as cursor:
# 1.获取当前库存
cursor.execute(sql_get_stock, (product_id,))
result = cursor.fetchone()
if not result:
print_warning(f"更新库存失败: 未找到产品 ID 为 {product_id} 的记录")
return False
current_stock = result["stock"]
# 2.计算新库存
new_stock = current_stock + stock
if new_stock < 0:
print_warning(f"更新库存失败: 库存不足 (当前库存: {current_stock}, 增加数量: {stock})")
return False
# 3.执行更新
affected_rows = cursor.execute(sql_update, (new_stock, product_id))
conn.commit()
if affected_rows > 0:
print_success(f"成功更新产品 ID 为 {product_id} 的库存为 {new_stock}")
return True
else:
# rowcount 为 0 可能是 ID 不存在或值未改变
print_warning(f"更新库存失败: 未找到影响行数 (affected_rows: {affected_rows})")
return False
except pymysql.Error as e:
print_error(f"更新库存失败: {e}")
conn.rollback()
return False
if __name__ == '__main__':
# 1.创建数据库连接
conn = get_mysql_connection()
if not conn:
exit(1)
# 2.更新商品库存
product_id = 1
stock = 10
if update_product_stock(conn, product_id, stock):
print_info(f"库存更新成功")
else:
print_error(f"库存更新失败")
def delete_product_by_id(conn: pymysql.connections.Connection, product_id: int) -> bool:
"""根据 ID 删除产品。"""
sql = "DELETE FROM products WHERE id = %s"
try:
with conn.cursor() as cursor:
affected_rows = cursor.execute(sql, (product_id,))
conn.commit()
if affected_rows > 0:
print_success(f"成功删除产品 ID={product_id}")
return True
else:
print_warning(f"删除产品 ID={product_id} 时未找到匹配记录。")
return False
except pymysql.Error as e:
print_error(f"删除产品 ID={product_id} 失败: {e}")
conn.rollback()
return False
# --- 使用 ---
if __name__ == '__main__':
# 1.创建数据库连接
conn = get_mysql_connection()
if not conn:
exit(1)
# 测试删除产品
affected_rows = delete_product_by_id(conn, 1)
if affected_rows:
print_success("删除产品成功")
else:
print_warning("删除产品失败")
*
)def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_header("演示 SELECT 查询")
try:
with conn.cursor() as cursor:
cursor.execute("SELECT * FROM products") # 核心知识点
result = cursor.fetchall() # 获取全部结果集
print_info(f"查询结果: {len(result)}条结果")
for row in result:print_result_item(row)
except pymysql.Error as e:print_error(f"查询失败: {e}")
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("查询所有产品的名称和价格 (前 5 条)")
try:
with conn.cursor() as cursor:
cursor.execute("SELECT name, price FROM products LIMIT 5") # 核心知识点 # 核心知识点
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
WHERE
)def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("查询商品分类 ID 为 2的产品")
try:
category_id_to_find = 2 # 假设要查询分类 ID 为 2 的产品
with conn.cursor() as cursor:
cursor.execute("SELECT id,name,price FROM products Where category_id = %s", (category_id_to_find,)) # 核心知识点
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("查询商品价钱大于100的所商品名和价格")
try:
max_price = 100.0
with conn.cursor() as cursor:
cursor.execute("SELECT name,price FROM productS WHERE price > %s ORDER BY price DESC", (max_price,))
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("查询商品价格在100~500之间的商品,且通过商品价格降序排序")
try:
min_p, max_p = 100.0, 500.0
with conn.cursor() as cursor:
cursor.execute("SELECT name, price FROM products WHERE price BETWEEN %s AND %s ORDER BY price", (min_p, max_p))
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("查询类别 ID 为 1 或 4 (电子产品或家居生活) 的产品")
try:
target_cat_ids = (1, 4)
placeholders = ', '.join(['%s'] * len(target_cat_ids)) # 有几个target_cat_ids参数就有几个占位符 %s
with conn.cursor() as cursor:
cursor.execute(f"SELECT id, name, category_id FROM products WHERE category_id IN ({placeholders})",target_cat_ids)
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("查询名称包含 '运动' 的产品")
try:
search_term = "运动"
with conn.cursor() as cursor:
# 注意 % 需要在参数中传递,而不是直接放在 SQL 字符串里
cursor.execute("SELECT name, price FROM products WHERE name LIKE %s", (f"%{search_term}%",))
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("查询没有上架日期的产品")
try:
with conn.cursor() as cursor:
cursor.execute("SELECT id,name FROM products WHERE added_date IS NULL")
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("查询图书音像类(ID=2) 或 价格低于100元 且 库存大于0 的产品")
try:
cat_id = 2
max_p = 100.0
with conn.cursor() as cursor:
cursor.execute("SELECT * FROM products WHERE (category_id = %s OR price < %s) AND stock > 0", (cat_id, max_p))
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
ORDER BY
)def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("按库存数量降序排列产品 (前 5)")
try:
with conn.cursor() as cursor:
cursor.execute("SELECT name,stock FROM products ORDER BY stock DESC LIMIT 5")
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("按类别 ID 升序、价格降序排列产品 (前 5)")
try:
with conn.cursor() as cursor:
# 确保先按 category_id 排序,相同 category_id 的再按 price 降序
cursor.execute("SELECT * FROM products ORDER BY category_id ASC,price DESC LIMIT 5")
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
DISTINCT
)def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("查询现有的所有产品分类的 ID")
try:
with conn.cursor() as cursor:
cursor.execute("SELECT DISTINCT category_id FROM products WHERE category_id IS NOT NULL")
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("计算产品统计信息 (总数, 总库存, 平均价, 最高价, 最低价)")
sql = """
SELECT
COUNT(*) as total_products,
SUM(stock) as total_stock,
AVG(price) as average_price,
MAX(price) as max_price,
MIN(price) as min_price
FROM products
"""
try:
with conn.cursor() as cursor:
cursor.execute(sql)
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
GROUP BY
)def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("按类别名称分组,统计各类别的产品数量和平均价格")
sql = """
SELECT
c.name as category_name, -- 从 categories 表获取分类名称
COUNT(p.id) as num_products, -- 统计每个分类下产品数量
AVG(p.price) as avg_price -- 统计每个分类下产品平均价格
FROM products as p
LEFT JOIN categories as c ON p.category_id = c.category_id -- 左连接 categories 表
GROUP BY c.name -- 按照类别名称分组
ORDER BY num_products DESC -- 按产品数量倒序排列
"""
try:
with conn.cursor() as cursor:
cursor.execute(sql)
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
HAVING
)def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("按类别分组,并找出平均价格高于 3000 元的类别")
sql = """
SELECT
c.name as category_name,
COUNT(p.id) as num_products,
AVG(p.price) as avg_price
FROM
products p
JOIN categories c ON p.category_id = c.category_id -- 使用外键关联分类
GROUP BY c.name
HAVING AVG(p.price) > %s -- 使用HAVING 过滤分组结果
ORDER BY avg_price DESC -- 按平均价格降序排序
"""
try:
MIN_AVG_PRICE = 3000.0
with conn.cursor() as cursor:
cursor.execute(sql, (MIN_AVG_PRICE,))
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
LIMIT
/OFFSET
)print_subheader("分页查询: 获取第 3 页数据 (每页 4 条)")
page = 3
page_size = 4
offset = (page - 1) * page_size
sql = "SELECT id, name, price FROM products ORDER BY id LIMIT %s OFFSET %s"
try:
with conn.cursor() as cursor:
cursor.execute(sql, (page_size, offset))
results = cursor.fetchall()
print_info(f"第 {page} 页 (每页 {page_size} 条),获取 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e: print_error(f"查询失败: {e}")
JOIN
)INNER JOIN
)INNER JOIN: 必须要两边表都有匹配的数据才会出现在结果中。
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("内连接查询: 获取产品及其对应的类别名称 (只显示有类别的产品)")
sql = """
SELECT p.id,p.name AS product_name , p.price,c.name AS category_name
FROM products p
INNER JOIN categories c ON p.category_id = c.category_id -- 只返回两个表都能匹配上的行
ORDER BY p.name,c.name
LIMIT 5 -- 限制输出
"""
try:
with conn.cursor() as cursor:
cursor.execute(sql)
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
LEFT JOIN
)左连接保留左表所有行(右表无匹配则补空)
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("左连接查询: 获取所有产品及其类别名称 (没有类别的产品也会显示)")
sql = """
SELECT p.id, p.name AS product_name, p.price, c.name AS category_name
FROM products p
-- 这里换成 RIGHT JOIN 就会少一条记录 (没有类别的产品)
LEFT JOIN categories c ON p.category_id = c.category_id -- 返回左表(products)的所有行
ORDER BY p.id
"""
try:
with conn.cursor() as cursor:
cursor.execute(sql)
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
CASE
表达式def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("使用 CASE 表达式根据价格给产品分类")
sql = """
SELECT
name, price,
CASE
WHEN price < 100 THEN '入门级'
WHEN price BETWEEN 100 AND 999.99 THEN '标准级'
WHEN price BETWEEN 1000 AND 4999.99 THEN '进阶级'
ELSE '旗舰级'
END AS price_level -- 给 CASE 结果起别名
FROM products
ORDER BY price
LIMIT 6 -- 限制输出
"""
try:
with conn.cursor() as cursor:
cursor.execute(sql)
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
def select_demonstration(conn: pymysql.Connection) -> None:
"""演示 SELECT 查询"""
print_subheader("子查询: 查找价格高于平均价格的产品")
sql = """
SELECT ID,name,price
FROM products
WHERE price > (SELECT AVG(price) FROM products)
ORDER BY price DESC
"""
try:
with conn.cursor() as cursor:
cursor.execute(sql)
results = cursor.fetchall()
print_info(f"查询到 {len(results)} 条记录:")
for row in results: print_result_item(row)
except pymysql.Error as e:
print_error(f"查询失败: {e}")
更详细的SQL语法,详见MYSQL篇章
SQLAlchemy Core 提供了一种使用 Python 对象构建 SQL 语句的方式,避免直接拼接字符串,从而提高代码的安全性和可维护性。本节将核心概念和常用 API 通过表格和简洁示例进行介绍。
首先,你需要安装 SQLAlchemy 库:
pip install SQLAlchemy
SQLAlchemy 本身不包含数据库驱动程序 (DBAPI),它依赖于第三方驱动来与具体的数据库进行通信。因此,你还需要根据你使用的数据库安装相应的驱动。SQLAlchemy 通过驱动实现与数据库的连接和交互。
常用数据库及其推荐驱动:
数据库 | SQLAlchemy 连接 URL 方言 | 推荐驱动 (DBAPI) | 安装命令 (示例) |
---|---|---|---|
PostgreSQL | postgresql |
psycopg2 (binary) |
pip install psycopg2-binary |
MySQL | mysql |
mysqlclient |
pip install mysqlclient |
mysql |
PyMySQL |
pip install PyMySQL |
|
SQLite | sqlite |
sqlite3 (内置) |
(无需额外安装) |
Microsoft SQL | mssql |
pyodbc |
pip install pyodbc |
Oracle | oracle |
cx_Oracle |
pip install cx_Oracle |
选择驱动:
psycopg2
是最常用且功能最全的选择。psycopg2-binary
包含了预编译的 C 扩展,安装更方便。mysqlclient
是一个性能较好的 C 扩展驱动,但可能需要编译环境。PyMySQL
是纯 Python 实现,安装简单,兼容性好,在许多场景下性能也足够。SQLAlchemy 连接 URL 中需要指定驱动,如 mysql+pymysql://...
或 mysql+mysqlclient://...
。sqlite3
驱动是 Python 内置的。本章我们还是使用PyMysql作为核心驱动!
Engine
是 SQLAlchemy 应用与数据库交互的起点,负责管理连接池和数据库方言。
创建引擎 (create_engine
)
使用 sqlalchemy.create_engine()
函数创建。
参数 | 类型 | 描述 | 示例 (SQLite 文件) |
---|---|---|---|
url |
str |
必需. 数据库连接 URL,格式为 dialect+driver://user:password@host:port/database 。具体格式见下方说明。 |
"sqlite:///myapp.db" |
echo |
bool |
可选,默认为 False 。设为 True 时,打印 SQLAlchemy 执行的所有 SQL 语句。调试时非常有用。 |
echo=True |
pool_size |
int |
可选,连接池保持的最小连接数。 | pool_size=5 |
max_overflow |
int |
可选,超出 pool_size 后允许额外创建的最大连接数。 |
max_overflow=10 |
pool_recycle |
int |
可选,连接在连接池中保持多少秒后被回收 (防止数据库服务器因超时关闭连接)。 | pool_recycle=3600 (1 小时) |
connect_args |
dict |
可选,传递给底层 DBAPI connect() 方法的额外参数。 |
connect_args={'timeout': 10} (某些驱动) |
execution_options |
dict |
可选,设置执行选项,如 isolation_level (事务隔离级别)。 |
|
json_serializer |
callable |
可选,用于序列化 JSON 数据。 | |
json_deserializer |
callable |
可选,用于反序列化 JSON 数据。 |
数据库连接 URL ( url
) 详解:
dialect[+driver]://[user[:password]@][host][:port]/[database][?key=value&key=value...]
dialect
: 数据库类型 (sqlite
, mysql
, postgresql
, mssql
, oracle
等)。driver
(可选): 使用的 DBAPI 库 (psycopg2
, pymysql
, mysqlclient
, pyodbc
等)。如果省略,SQLAlchemy 会尝试默认驱动。sqlite:///path/to/your/database.db
(注意三个 /
)sqlite:///:memory:
(或者仅 sqlite://
)mysql+pymysql://user:pass@host:3306/dbname?charset=utf8mb4
postgresql+psycopg2://user:pass@host:5432/dbname
# core/engine_examples.py
from sqlalchemy import create_engine
from print_utils import print_header, print_info, print_success, print_error
# --- 创建不同数据库的 Engine 示例 ---
print_header("创建 SQLAlchemy Engine 示例")
try:
# 1. SQLite (文件数据库)
# URL: sqlite:///path/to/database.db (三个斜杠表示相对或绝对路径)
sqlite_file_engine = create_engine("sqlite:///sqlalchemy_core_example.db", echo=True)
# echo=True 会打印 SQLAlchemy 执行的 SQL 语句,非常适合学习和调试
print_success("创建 SQLite 文件数据库引擎成功 (带 SQL 日志)。")
# 可以在这里测试连接 (可选)
# with sqlite_file_engine.connect() as conn:
# print_success("SQLite 文件数据库连接测试成功。")
# 2. SQLite (内存数据库)
# URL: sqlite:///:memory: (四个斜杠或仅 :memory:)
sqlite_memory_engine = create_engine("sqlite:///:memory:", echo=False)
print_success("创建 SQLite 内存数据库引擎成功。")
# 3. PostgreSQL (使用 psycopg2 驱动)
# URL: postgresql+psycopg2://user:password@host:port/database
# 替换为你的实际 PostgreSQL 配置
#pg_url = "postgresql+psycopg2://your_user:your_password@localhost:5432/your_database"
# pg_engine = create_engine(pg_url, echo=True, pool_size=5, max_overflow=10)
# pool_size: 连接池中保持的最小连接数
# max_overflow: 超出 pool_size 后允许临时创建的最大连接数
#print_info(f"PostgreSQL 连接 URL (示例): {pg_url} (请替换为实际配置)")
# print_success("创建 PostgreSQL 引擎配置示例 (pool_size=5, max_overflow=10)。")
# 4. MySQL (使用 PyMySQL 驱动)
# URL: mysql+pymysql://user:password@host:port/database?charset=utf8mb4
# 推荐在 URL 中指定 charset=utf8mb4
mysql_url = "mysql+pymysql://root:root@localhost:3306/pymysql_demo?charset=utf8mb4"
mysql_engine = create_engine(mysql_url, echo=True)
print_success("创建 MySQL 引擎配置示例 (pymysql_demo)。")
except ImportError as ie:
print_error(f"创建引擎失败:缺少必要的数据库驱动。请安装相应的库。错误: {ie}")
except Exception as e:
print_error(f"创建引擎时发生意外错误: {e}")
使用 MetaData
, Table
, Column
对象在 Python 代码中定义数据库模式。
核心类与概念
类/概念 | 描述 |
---|---|
MetaData |
表和其他模式对象的容器。通常一个数据库一个实例。 |
Table |
代表数据库中的一张表。关联到 MetaData 。 |
Column |
代表表中的一列。包含名称、类型和约束。 |
SQLAlchemy Types | 独立于数据库的类型,如 Integer , String , DateTime , Boolean 等。 |
Constraints | 约束条件,如主键、外键、唯一、非空、检查约束、索引。 |
from sqlalchemy import MetaData, Table, Column, Integer, String
metadata = MetaData()
user_table = Table(
"user",
metadata,
Column("id", Integer, primary_key=True),
Column("name", String(50), nullable=False),
Column("email", String(120), unique=True),
)
简单来说,SQLAlchemy 中的 MetaData
对象就是一个数据库结构蓝图的“登记簿”。
它主要干两件事:
Table
对象)、列、索引、外键等这些“蓝图”信息。metadata.create_all(engine)
就能把所有登记在册的表都建出来Table
构造函数关键参数
参数 | 类型 | 描述 |
---|---|---|
name |
str |
必需. 表名。 |
metadata |
MetaData |
必需. 关联的 MetaData 对象。 |
*columns |
Column , Constraint , Index 等对象 |
必需. 定义表的列和表级约束。 |
**kwargs |
特定数据库方言的选项 (如 mysql_engine='InnoDB' )。 |
Column
构造函数关键参数
参数 | 类型 | 描述 |
---|---|---|
name |
str |
列名 (通常省略,SQLAlchemy 会使用属性名)。 |
type_ |
SQLAlchemy Type | 必需. 列的数据类型 (如 Integer , String(50) , DateTime )。 |
primary_key |
bool |
是否为主键。 |
nullable |
bool |
是否允许为空 (默认为 True )。 |
default |
Any |
Python 级别的默认值。 |
server_default |
str 或 SQL Expression |
数据库级别的默认值 (如 func.now() 或 "0" )。 |
unique |
bool |
是否唯一。 |
index |
bool |
是否为此列创建索引。 |
ForeignKey(...) |
ForeignKey 对象 |
定义外键约束,指向 目标表名.目标列名 。 |
onupdate |
Any 或 SQL Expression |
更新行时自动设置的值 (常用于时间戳)。 |
server_onupdate |
WorkspaceedValue 或 str 等 |
数据库级别的更新触发器。 |
comment |
str |
列注释 (部分数据库支持)。 |
主键、唯一、索引、默认值、非空等示例
from sqlalchemy import Table, Column, Integer, String, Boolean, DateTime, func
user_table = Table(
"user",
metadata,
Column("id", Integer, primary_key=True),
Column("username", String(50), nullable=False, unique=True, index=True),
Column("email", String(120), nullable=False),
Column("created_at", DateTime, server_default=func.now()), # 数据库默认值
Column("is_active", Boolean, default=True), # Python 级默认值
)
外键约束示例
from sqlalchemy import ForeignKey
address_table = Table(
"address",
metadata,
Column("id", Integer, primary_key=True),
Column("user_id", Integer, ForeignKey("user.id")),
Column("address", String(255)),
)
表级约束与索引示例
from sqlalchemy import Table, Column, Integer, String, UniqueConstraint, Index
metadata = MetaData()
# 定义名为 product 的表
product_table = Table(
"product", # 表名为 "product"
metadata, # 关联的 MetaData 对象
Column("id", Integer, primary_key=True), # 主键列,整数类型,自动增长
Column("name", String(100)), # 产品名称列,字符串,最大长度100
Column("sku", String(50)), # SKU(库存单位)列,字符串,最大长度50
# 表级唯一约束,确保 sku 值在整个表中唯一
UniqueConstraint("sku", name="uix_1"),
# 在 name 列上创建索引,加快查询速度(例如根据产品名称搜索)
Index("ix_product_name", "name"),
)
常用 SQLAlchemy 类型
类型 | 映射的常见 SQL 类型 (示例) |
---|---|
Integer / SmallInteger / BigInteger |
INTEGER , SMALLINT , BIGINT |
String(length) |
VARCHAR(length) |
Text |
TEXT , CLOB |
Numeric(prec, scale) |
NUMERIC(prec, scale) , DECIMAL |
Float(precision) |
FLOAT , REAL |
Boolean |
BOOLEAN , SMALLINT (0/1) |
Date |
DATE |
Time |
TIME |
DateTime |
TIMESTAMP , DATETIME |
LargeBinary |
BLOB , BYTEA |
JSON |
JSON , JSONB (需要数据库支持) |
# core/metadata_table_example.py
from sqlalchemy import (MetaData, Table, Column, Integer, String, DateTime,
Boolean, Numeric, ForeignKey, Index, PrimaryKeyConstraint,
UniqueConstraint, CheckConstraint, func, TIMESTAMP, Text) # func 用于 server_default
from print_utils import print_header, print_info, print_success, print_error
from sqlalchemy import create_engine
import pymysql
# 数据库配置
DB_CONFIG = {
"host": "localhost",
"port": 3306,
"user": "root",
"password": "root",
"db_name": "pymysql_demo"
}
def create_database():
"""创建数据库(如果不存在)"""
try:
# 连接到MySQL(不指定数据库)
conn = pymysql.connect(
host=DB_CONFIG["host"],
port=DB_CONFIG["port"],
user=DB_CONFIG["user"],
password=DB_CONFIG["password"]
)
cursor = conn.cursor()
# 检查数据库是否存在
cursor.execute(f"SHOW DATABASES LIKE '{DB_CONFIG['db_name']}'")
result = cursor.fetchone()
# 如果数据库不存在,则创建
if not result:
print_info(f"数据库 '{DB_CONFIG['db_name']}' 不存在,正在创建...")
cursor.execute(f"CREATE DATABASE {DB_CONFIG['db_name']} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")
print_success(f"数据库 '{DB_CONFIG['db_name']}' 创建成功")
else:
print_info(f"数据库 '{DB_CONFIG['db_name']}' 已存在")
conn.commit()
cursor.close()
conn.close()
except Exception as e:
print_error(f"创建数据库时出错: {str(e)}")
raise e
def create_table():
print_header("定义 SQLAlchemy Core 表结构示例 (MySQL)")
# 1.创建 MetaData 对象
metadata_obj = MetaData()
print_info("Meta 对象已创建")
# 定义 "categories" 表
categories_table = Table(
"categories", # 表名
metadata_obj, # 关联的 MetaData 对象
Column("category_id", Integer, primary_key=True, autoincrement=True, comment="类别ID"), # 主键,自增
Column("name", String(50), nullable=False, unique=True, comment="类别名称"), # 非空、唯一、注释
mysql_engine="InnoDB", # 指定存储引擎
mysql_charset="utf8mb4", # 指定字符集
mysql_collate="utf8mb4_unicode_ci", # 指定排序规则
mysql_row_format="DYNAMIC" # 指定行格式
)
# 3. 定义 "products" 表
products_table = Table(
"products",
metadata_obj,
Column("id", Integer, primary_key=True, autoincrement=True, comment="产品ID"), # 主键,自增
Column("name", String(100), nullable=False, comment="产品名称"),
Column("price", Numeric(10, 2), nullable=False, comment="价格"),
Column("category_id", Integer, ForeignKey("categories.category_id"), nullable=False, comment="类别ID"), # 外键
# Mysql中常用TIMESTAMP类型,并设置自动更新
Column("created_at", TIMESTAMP, server_default=func.now(), comment="创建时间"),
Column("updated_at", TIMESTAMP, server_default=func.now(), onupdate=func.now(), comment="更新时间"),
Column("is_active", Boolean, default=True, comment="是否激活"),
mysql_engine="InnoDB", # 指定存储引擎,支持事务和外键
mysql_charset="utf8mb4", # 指定字符集,支持中文
mysql_collate="utf8mb4_unicode_ci", # 指定排序规则,支持中文
mysql_row_format="DYNAMIC" # 指定行格式,支持动态行
)
return metadata_obj, categories_table, products_table
# 4.创建表
def create_tables(engine, metadata_obj):
metadata_obj.create_all(bind=engine)
print_success("表创建完成")
def check_table(metadata_obj, products_table):
print_info("\nMetaData 中包含的表:")
for table_name in metadata_obj.tables:
print(f" - {table_name}")
print_info("\n'products' 表的列信息:")
for column in products_table.columns:
print(f" - 列名: {column.name}, 类型: {column.type}, 主键: {column.primary_key}, 外键: {column.foreign_keys}")
if __name__ == '__main__':
# 先确保数据库存在
create_database()
# 创建数据库连接
engine = create_engine(f"mysql+pymysql://{DB_CONFIG['user']}:{DB_CONFIG['password']}@{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['db_name']}", echo=True)
# 创建表结构定义
metadata_obj, categories_table, products_table = create_table()
# 创建表
create_tables(engine, metadata_obj)
# 检查表结构
check_table(metadata_obj, products_table)
定义好表之后,就可以使用 SQLAlchemy Core 构建并执行 INSERT, SELECT, UPDATE, DELETE 语句了。
核心执行流程
with engine.connect() as connection:
或 with engine.begin() as connection:
(推荐,自动事务管理)。insert(table)
, select(columns)
, update(table)
, delete(table)
。.values(...)
, .where(...)
, .order_by(...)
, .limit(...)
, .offset(...)
, .join(...)
等方法。result = connection.execute(statement, [parameters])
。engine.connect()
, DML 操作后需 connection.commit()
或 connection.rollback()
。如果使用 engine.begin()
, 事务自动处理。常用语句构建函数/方法
函数/方法 | 用途 | 示例 |
---|---|---|
insert(table) |
构建 INSERT 语句。 | stmt = insert(products_table).values(name="P1", price=10) |
select(...) |
构建 SELECT 语句。参数是列对象或表对象。 | stmt = select(products_table.c.name, products_table.c.price) |
update(table) |
构建 UPDATE 语句。 | stmt = update(products_table).where(products_table.c.id == 1) |
delete(table) |
构建 DELETE 语句。 | stmt = delete(products_table).where(products_table.c.stock == 0) |
.values(...) |
(INSERT/UPDATE) 指定要插入或更新的列和值。 | .values(name="P2", price=20) |
.where(...) |
(SELECT/UPDATE/DELETE) 指定过滤条件。 | .where(products_table.c.price > 100) |
.order_by(...) |
(SELECT) 指定排序方式 (.asc() , .desc() )。 |
.order_by(products_table.c.price.desc()) |
.limit(n) |
(SELECT) 限制返回结果数量。 | .limit(10) |
.offset(n) |
(SELECT) 跳过指定数量的结果 (用于分页)。 | .offset(20) |
.join(...) |
(SELECT) 连接其他表。参数通常是目标表和连接条件。 | .join(categories_table, products_table.c.category_id == ...) |
.label(name) |
给列或表达式设置别名。 | select(products_table.c.name.label("product_name")) |
text(sql) |
用于执行原生 SQL 字符串 (配合参数绑定使用)。 | stmt = text("SELECT * FROM products WHERE id = :id") |
常用条件操作符/函数
操作符/函数 | 用途 | 示例 |
---|---|---|
== , != , < , > |
基本比较。 | table.c.price == 100 |
& / and_(...) |
逻辑与 (AND)。 | (table.c.price > 10) & (table.c.stock > 0) |
` | / or_(…)` |
逻辑或 (OR)。 |
~ / not_(...) |
逻辑非 (NOT)。 | ~table.c.name.like('A%') |
.like(pattern) |
SQL LIKE 操作 ( % 通配符)。 |
table.c.name.like('%Core%') |
.ilike(pattern) |
不区分大小写的 LIKE (某些数据库)。 | table.c.name.ilike('%core%') |
.in_([...]) |
SQL IN 操作。 | table.c.category_id.in_([1, 3, 5]) |
.is_(None) |
SQL IS NULL。 | table.c.added_date.is_(None) |
.isnot(None) |
SQL IS NOT NULL。 | table.c.description.isnot(None) |
.between(a, b) |
SQL BETWEEN 操作。 | table.c.price.between(100, 500) |
func. |
SQL 函数 (如 func.now() , func.count() ) |
select(func.count(products_table.c.id)) |
结果处理 (ResultProxy
/ CursorResult
)
connection.execute()
返回一个结果对象,用于获取数据。
方法 | 描述 |
---|---|
Workspaceall() |
获取所有行,返回 Row 对象列表。 |
Workspaceone() |
获取下一行,返回 Row 对象或 None 。 |
Workspacemany(size=None) |
获取指定数量的行,返回 Row 对象列表。 |
scalar() |
获取结果的第一行第一列的值,无结果或多列时行为可能变化或报错。 |
scalars() (2.0+) |
返回 ScalarResult ,迭代产生每行第一列的值。 |
first() |
获取第一行 (Row 对象),无结果返回 None 。常用。 |
one() |
获取唯一一行,无结果或多于一行时报错。 |
one_or_none() |
获取唯一一行,无结果返回 None ,多于一行时报错。 |
mappings() (2.0+) |
返回 MappingResult ,迭代产生字典形式的行。 |
(直接迭代结果对象) | 每次迭代返回一个 Row 对象。 |
rowcount (属性) |
DML 操作影响的行数。 |
inserted_primary_key (属性) |
INSERT 操作的自增主键 (元组)。可能只包含第一个插入行的主键。 |
keys() (属性) |
返回结果的列名列表。 |
Row
对象访问:
row[0]
row['column_name']
或 row.column_name
row[table.c.column_name]
dict(row._mapping)
(SQLAlchemy 2.0+) 或 dict(row)
(旧版)# core/metadata_table_example.py
from sqlalchemy import (MetaData, Table, Column, Integer, String, DateTime,
Boolean, Numeric, ForeignKey, Index, PrimaryKeyConstraint,
UniqueConstraint, CheckConstraint, func, TIMESTAMP, Text) # func 用于 server_default
from print_utils import print_header, print_info, print_success, print_error, print_subheader, print_sql,print_warning
from sqlalchemy import create_engine
# --- 固定的代码,请放在最上面 ---
DB_CONFIG = {
'DB_URI': 'mysql+pymysql://root:root@localhost:3306/pymysql_demo'
}
engine = create_engine(DB_CONFIG['DB_URI'])
products_table: Table = Table('products', MetaData(),
Column('id', Integer, primary_key=True),
Column('name', String(50), nullable=False),
Column('price', Numeric(10, 2), nullable=False),
Column('category_id', Integer, ForeignKey('categories.category_id'), nullable=False),
Column('is_active', Boolean, default=True)
)
categories_table = Table('categories', MetaData(),
Column('category_id', Integer, primary_key=True),
Column('name', String(50), nullable=False)
)
def insert_category_data():
"""
插入类别表数据
"""
print_header("SQLAlchemy 类别表 插入示例")
###### 插入单条记录 (INSERT Single)
print_subheader("1. 插入单个产品记录")
stmt_insert_single = insert(categories_table).values(
name="书籍类"
)
print_sql(str(stmt_insert_single)) # 打印生成的 SQL
inserted_pk = None
try:
with engine.begin() as connection: # 自动事务管理
result = connection.execute(stmt_insert_single)
if result.inserted_primary_key:
inserted_pk = result.inserted_primary_key[0] # 获取自增主键
print_success(f"成功插入类别 '书籍类' (ID: {inserted_pk})")
elif result.rowcount == 1:
print_success("成功插入类别 '书籍类' (无法获取主键)")
else:
print_warning("插入类别 '书籍类' 时 rowcount 不为 1")
except Exception as e:
print_error(f"插入类别 '书籍类' 时出错: {e}")
def insert_product_data():
"""
插入产品表数据
"""
print_header("SQLAlchemy 产品表 插入示例")
# 先获取已插入的类别ID
with engine.connect() as conn:
result = conn.execute(categories_table.select().where(categories_table.c.name == "书籍类"))
category = result.fetchone()
if not category:
print_error("未找到'书籍类'类别")
return
category_id = category[0] # 获取实际的category_id值
# 使用具体的category_id值插入产品
stmt_insert_single = insert(products_table).values(
name="SQLAlchemy书籍",
price=123.45,
category_id=category_id, # 使用具体的ID值
is_active=True
)
print_sql(str(stmt_insert_single)) # 打印生成的 SQL
inserted_pk = None
try:
with engine.begin() as connection: # 自动事务管理
result = connection.execute(stmt_insert_single)
if result.inserted_primary_key:
inserted_pk = result.inserted_primary_key[0] # 获取自增主键
print_success(f"成功插入产品 'SQLAlchemy书籍' (ID: {inserted_pk})")
elif result.rowcount == 1:
print_success("成功插入产品 'SQLAlchemy书籍' (无法获取主键)")
else:
print_warning("插入产品 'SQLAlchemy书籍' 时 rowcount 不为 1")
except Exception as e:
print_error(f"插入单个产品时出错: {e}")
if __name__ == '__main__':
# insert_category_data() 先执行类别表插入
insert_product_data()
def insert_many_products():
print_header("SQLAlchemy 批量插入示例")
print_subheader("1. 批量插入多个产品记录")
stmt_insert_many = insert(products_table).values([
{'name': 'Python编程', 'price': 123.45, 'category_id': 1},
{'name': 'SQLAlchemy教程', 'price': 99.99, 'category_id': 1},
{'name': '大话西游之月光宝盒', 'price': 150.00, 'category_id': 1}
])
print_sql(str(stmt_insert_many))
try:
with engine.begin() as connection: # 使用事务
result = connection.execute(stmt_insert_many)
print_success(f"成功插入 {result.rowcount} 条记录")
except Exception as e:
print_error(f"插入失败: {e}")
def update_product():
print_header("SQLAlchemy 更新示例")
# 更新商品价格
stmt_update = update(products_table).where(products_table.c.id == 1).values(
price = 999999
)
print_sql(str(stmt_update))
try:
with engine.begin() as connection:
result = connection.execute(stmt_update)
print_success(f"成功更新 {result.rowcount} 条记录")
except Exception as e:
print_error(f"更新失败: {e}")
def update_products():
print_header("SQLAlchemy 批量更新示例")
category_id_to_update = 1
price_increase_factor = 1.10 # 价格提高10%
# 构建 UPDATE 语句
stmt_update = update(products_table).where(
products_table.c.category_id == category_id_to_update
).values(
price = products_table.c.price * price_increase_factor
)
print_sql(str(stmt_update))
try:
with engine.begin() as connection:
result = connection.execute(stmt_update)
print_success(f"成功更新 {result.rowcount} 条记录")
except Exception as e:
print_error(f"更新失败: {e}")
def delete_product():
print_header("SQLAlchemy 删除示例")
product_id_to_delete = 1
stmt_delete = delete(products_table).where(products_table.c.id == product_id_to_delete)
print_sql(str(stmt_delete))
try:
with engine.begin() as connection:
result = connection.execute(stmt_delete)
print_success(f"成功删除 {result.rowcount} 条记录")
except Exception as e:
print_error(f"删除失败: {e}")
def delete_products():
print_header("SQLAlchemy 批量删除示例")
# 批量删除数据,例如:isActive = False
stmt_delete = delete(products_table).where(products_table.c.is_active == False)
print_sql(str(stmt_delete))
try:
with engine.begin() as connection:
result = connection.execute(stmt_delete)
print_success(f"成功删除 {result.rowcount} 条记录")
except Exception as e:
print_error(f"删除失败: {e}")
在进入最重要的查询的基础前,我们可以看到代码有很多是重复的,每一次都要进行begin,捕获…这个繁杂的过程会导致代码冗余,我们可以采用AOP的思想,去实现一个事务的装饰器,如下:
from functools import wraps
# 添加事务装饰器 - AOP思想实现
def transactional(func=None, *, engine_obj=None):
"""
使用AOP思想实现的事务装饰器
可以直接应用于函数上,无需手动管理事务
用法:
@transactional # 使用默认引擎
def my_function():
# 执行SQL操作,自动事务管理
@transactional(engine_obj=custom_engine) # 使用自定义引擎
def another_function():
# 执行SQL操作,自动事务管理
"""
_engine = engine_obj or engine # 如果未指定引擎,使用全局引擎
def decorator(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# 创建connection参数
if 'connection' in kwargs:
# 如果已经有连接传入,则使用该连接
return fn(*args, **kwargs)
try:
with _engine.begin() as connection:
# 注入connection到被装饰函数的kwargs中
kwargs['connection'] = connection
result = fn(*args, **kwargs)
return result
except Exception as e:
print_error(f"事务执行失败: {e}")
raise
return wrapper
# 处理@transactional和@transactional()两种调用方式
if func is None:
return decorator
return decorator(func)
通过这个装饰器,我们就可以实现代码的优雅性
# 使用事务装饰器实现的方法
@transactional
def insert_many_products_aop(connection=None):
"""使用AOP事务装饰器实现的批量插入"""
print_header("SQLAlchemy 批量插入示例 (使用事务装饰器)")
stmt_insert_many = insert(products_table).values([
{'name': 'Python编程AOP版', 'price': 123.45, 'category_id': 1},
{'name': 'SQLAlchemy教程AOP版', 'price': 99.99, 'category_id': 1},
{'name': '大话西游之月光宝盒AOP版', 'price': 150.00, 'category_id': 1}
])
print_sql(str(stmt_insert_many))
# 不需要手动管理事务,直接执行
result = connection.execute(stmt_insert_many)
print_success(f"成功插入 {result.rowcount} 条记录")
@transactional
def select_all_products(connection=None):
# --- 查询所有产品的所有列 ---
print_info("查询所有产品的 '*' (前 5 条):")
stmt_select_all = select(products_table).limit(5)
print_sql(str(stmt_select_all))
result = connection.execute(stmt_select_all)
for row in result.fetchmany(5):
print(row)
# core/metadata_table_example.py
from sqlalchemy import and_, or_, not_ # 导入逻辑操作符
from sqlalchemy import * # func 用于 server_default
from print_utils import print_header, print_info, print_success, print_error, print_subheader, print_sql, print_warning
# ====================== 基础查询 ======================
@transactional
def select_all_products(connection=None):
print_info("查询所有产品的 '*' (前 5 条):")
stmt_select_all = select(products_table).limit(5)
print_sql(str(stmt_select_all))
result = connection.execute(stmt_select_all)
for row in result.fetchmany(5):
print(row)
# ====================== 条件查询 ======================
@transactional
def select_by_category_id(category_id_to_find, connection=None):
print_info(f"\n查询类别 ID 为 {category_id_to_find} 的产品:")
stmt_eq = select(products_table.c.id, products_table.c.name, products_table.c.price).where(
products_table.c.category_id == category_id_to_find
)
print_sql(str(stmt_eq))
result = connection.execute(stmt_eq)
for row in result:
print(row)
@transactional
def select_by_price_gt(max_price, connection=None):
print_info(f"\n查询价格大于 {max_price} 的产品名称和价格 (降序):")
stmt_gt = select(products_table.c.name, products_table.c.price).where(
products_table.c.price > max_price
).order_by(products_table.c.price.desc())
print_sql(str(stmt_gt))
result = connection.execute(stmt_gt)
for row in result:
print(row)
@transactional
def select_by_price_between(min_p, max_p, connection=None):
print_info(f"\n查询价格在 {min_p} 到 {max_p} 之间的产品名称和价格 (升序):")
stmt_between = select(products_table.c.name, products_table.c.price).where(
products_table.c.price.between(min_p, max_p)
).order_by(products_table.c.price)
print_sql(str(stmt_between))
result = connection.execute(stmt_between)
for row in result:
print(row)
@transactional
def select_by_category_ids(target_cat_ids, connection=None):
print_info(f"\n查询类别 ID 在 {target_cat_ids} 中的产品:")
stmt_in = select(products_table.c.id, products_table.c.name, products_table.c.category_id).where(
products_table.c.category_id.in_(target_cat_ids)
)
print_sql(str(stmt_in))
result = connection.execute(stmt_in)
for row in result:
print(row)
@transactional
def select_by_name_like(search_term, connection=None):
print_info(f"\n查询名称包含 '{search_term}' 的产品:")
stmt_like = select(products_table.c.name, products_table.c.price).where(
products_table.c.name.like(f"%{search_term}%")
)
print_sql(str(stmt_like))
result = connection.execute(stmt_like)
for row in result:
print(row)
@transactional
def alter_add_stock_column(connection=None):
print_info("增加 stock 字段到 products 表:")
stmt_alter = text("ALTER TABLE products ADD COLUMN stock INTEGER NOT NULL")
print_sql(str(stmt_alter))
connection.execute(stmt_alter)
@transactional
def select_by_complex_condition(cat_id, min_stock, connection=None):
print_info(f"\n查询类别 ID 为 {cat_id} 且 (价格小于 50 或 库存大于 {min_stock}) 的产品:")
stmt_complex = select(products_table).where(
and_(
products_table.c.category_id == cat_id,
or_(
products_table.c.price < 50.0,
products_table.c.stock > min_stock
)
)
)
print_sql(str(stmt_complex))
result = connection.execute(stmt_complex)
for row in result:
print(row)
if __name__ == '__main__':
# 基础查询示例
select_all_products()
# 条件查询示例
select_by_category_id(1)
select_by_price_gt(100.0)
select_by_price_between(100.0, 500.0)
select_by_category_ids([1, 3])
select_by_name_like("Core")
select_by_complex_condition(1, 10)
print_subheader("7. 查询数据 (排序 ORDER BY)")
@transactional
def select_products_order_by_price_desc(connection=None):
print_info("\n按价格降序排列产品 (前 5 条):")
stmt_order_desc = select(
products_table.c.name, products_table.c.price
).order_by(
products_table.c.price.desc() # 使用 .desc() 指定降序
).limit(5)
print_sql(str(stmt_order_desc))
result = connection.execute(stmt_order_desc)
for row in result:
print(row)
@transactional
def select_products_order_by_multi_columns(connection=None):
print_info("\n按类别 ID 升序、价格降序排列产品 (前 5 条):")
stmt_order_multi = select(
products_table.c.name, products_table.c.category_id, products_table.c.price
).order_by(
products_table.c.category_id.asc(), # 先按类别 ID 升序
products_table.c.price.desc() # 同类别内按价格降序
).limit(5)
print_sql(str(stmt_order_multi))
result = connection.execute(stmt_order_multi)
for row in result:
print(row)
@transactional
def select_distinct_products():
print_info("\n查询所有类别 ID 并去重:")
stmt_distinct = select(products_table.c.category_id,products_table.c.name).distinct()
print_sql(str(stmt_distinct))
result = connection.execute(stmt_distinct)
for row in result:
print(row)
@transactional
def select_aggregation_products(connection=None):
print_info("\n计算产品统计信息 (总数, 总库存, 平均价, 最高价, 最低价):")
stmt_aggregation = select(
func.count(products_table.c.id).label("total_products"), # 计算产品总数并取别名
func.sum(products_table.c.stock).label("total_stock"), # 计算总库存并取别名
func.avg(products_table.c.price).label("average_price"), # 计算平均价格并取别名
func.max(products_table.c.price).label("max_price"), # 计算最高价格并取别名
func.min(products_table.c.price).label("min_price") # 计算最低价格并取别名
)
print_sql(str(stmt_aggregation))
result = connection.execute(stmt_aggregation).first() # 聚合查询通常只有一行元组结果
if result:
print_result_item(dict(result._mapping))
else:
print_info("未能获取产品统计信息。")
@transactional
def select_group_by_category(connection=None):
print_info("\n按类别分组计算产品统计信息:")
stmt_group_by = select(
products_table.c.category_id,
func.count(products_table.c.id).label("product_count"),
func.sum(products_table.c.price).label("price"),
).group_by(products_table.c.category_id).order_by(products_table.c.category_id.asc())
print_sql(str(stmt_group_by))
result = connection.execute(stmt_group_by).first()
if result:
print_result_item(dict(result._mapping))
else:
print_info("未能获取产品统计信息。")
@transactional
def select_group_by_category(connection=None):
print_info("\n按类别分组计算产品统计信息 (产品数量大于1的类别):")
stmt_group_by = select(
products_table.c.category_id, # 类别 ID
func.count(products_table.c.id).label("product_count"), # 产品数量
func.sum(products_table.c.price).label("price"), # 价格总和
).group_by(products_table.c.category_id)\
.having(
and_(
func.count(products_table.c.id) > 1, # 产品数量大于1
func.sum(products_table.c.price) > 0 # 价格总和大于0
)
)\
.order_by(products_table.c.category_id.asc())
print_sql(str(stmt_group_by))
result = connection.execute(stmt_group_by).first()
if result:
print_result_item(dict(result._mapping))
else:
print_info("未能获取产品统计信息。")
@transactional
def select_limit_pages(connection=None, page=1, page_size=3):
print_info(f"\n分页查询产品数据 (每页 {page_size} 条, 第 {page} 页):")
offset_val = (page - 1) * page_size
print_info(f"当前页码: {page}, 每页显示 {page_size} 条, 偏移量: {offset_val}")
# 获取总记录数
count_stmt = select(func.count()).select_from(products_table)
total_count = connection.execute(count_stmt).scalar()
total_pages = (total_count + page_size - 1) // page_size
# 分页查询
stmt_paging = select(products_table.c.name, products_table.c.price)\
.offset(offset_val)\
.limit(page_size)
print_sql(str(stmt_paging))
result = connection.execute(stmt_paging)
print_info(f"总记录数: {total_count}, 总页数: {total_pages}")
print_info("当前页数据:")
for row in result:
print_result_item(dict(row._mapping))
if __name__ == '__main__':
# 演示分页效果
for page in range(1,10):
select_limit_pages(page=page)
print_info("-" * 50)
@transactional
def select_products_with_join(connection=None):
print_info("\n内连接查询: 获取产品及其对应的类别名称 (只显示有类别的产品):")
stmt_inner_join = select(
products_table.c.name,
products_table.c.price,
categories_table.c.name.label("category_name") # 使用label设置别名
).join(
categories_table, # 指定要连接的表
products_table.c.category_id == categories_table.c.category_id # 指定连接条件
).where(
products_table.c.category_id == 1 # 指定过滤条件
) # 如果 ForeignKey 已定义,SQLAlchemy 通常能自动推断条件: .join(categories_table)
print_sql(str(stmt_inner_join))
result = connection.execute(stmt_inner_join)
for row in result:
print_result_item(dict(row._mapping))
@transactional
def select_products_with_level(connection=None):
print_info("\n使用 CASE 表达式根据价格给产品分类:")
stmt_case = select(
products_table.c.name,
products_table.c.price,
case(
(products_table.c.price > 500, "白金级"),
(products_table.c.price > 200, "进阶级"),
(products_table.c.price > 100, "标准级"),
(products_table.c.price > 10, "入门级"),
else_="未知"
).label("price_level")
)
print_sql(str(stmt_case))
result = connection.execute(stmt_case)
for row in result:
print_result_item(dict(row._mapping))
子查询方法快速参考
为了方便快速选择合适的子查询构建方式,这里根据核心使用场景进行了归纳:
方法/概念 | 核心场景 (一句话概括) | 关键点/提示 |
---|---|---|
.scalar_subquery() |
WHERE 中与单个预期值比较 | 子查询应返回单行单列。用于 >、=、< 等比较。 |
.subquery() |
FROM 或 JOIN 中像表一样使用子查询结果 | 像操作普通 Table 一样操作它;用 .c 访问列;聚合列需用 .label() 命名。 |
(直接用 select() ) |
WHERE 中与列表比较 (如 IN ) |
子查询应返回单列多行(如id);常与 .in_() 配合。 |
.exists() |
WHERE 中检查是否存在满足条件的关联行 | 只关心“有没有”,不关心“是什么”或“有多少”。 |
.alias() |
需要在同一查询中区分同一个表的多次引用时 | 相关子查询和自连接的必备工具。 |
.correlate() |
(较少用) 显式声明子查询依赖的外部表 | 主要用于非 WHERE /FROM 子句的子查询,或自动关联失效时。 |
.lateral() |
(需数据库支持) 子查询引用同 FROM 中之前的表 | 用于行级计算或复杂的 Top-N 查询。 |
选择思路小结:
.scalar_subquery()
.subquery()
select()
配合 .in_()
.exists()
.alias()
@transactional
def select_products_with_subquery(connection=None):
""" --- 子查询用在 WHERE 子句中 ---"""
print_info("\n子查询: 查找价格高于平均价格的产品:")
# 1.构建子查询,计算平均价格
# scalar_subquery : 返回一个标量值而不是一个结果集,用于where子句中比较
subquery_avg_price = select(func.avg(products_table.c.price)).scalar_subquery()
# 2.构建主查询,在 WHERE 子句中使用子查询结果
stmt_subquery_where = select(products_table.c.name, products_table.c.price).where(
products_table.c.price > subquery_avg_price
)
print_sql(str(stmt_subquery_where))
result = connection.execute(stmt_subquery_where)
for row in result:
print_result_item(dict(row._mapping))
@transactional
def select_products_with_subquery_in_select(connection=None):
""" --- 子查询用在 SELECT 子句中 ---"""
print_info("\n子查询: 查找每个类别中价格最高的产品 (使用派生表):")
# 1.构建子查询:找出每个 category_id 对应的最高价格
subquery_max_price = select(
products_table.c.category_id,
func.max(products_table.c.price).label("max_price")
).group_by(products_table.c.category_id).subquery()
# subquery 表示一个子查询,可以作为主查询的一部分 也可以作为一个用来join的对象
# 2. 将主表与子查询连接,找出价格等于该类别最高价的产品
stmt_subquery_from = select(
products_table.c.name,
products_table.c.price,
products_table.c.category_id,
).join( # 将主表与子查询连接
subquery_max_price,
and_( # 指定连接条件
products_table.c.category_id == subquery_max_price.c.category_id,
products_table.c.price == subquery_max_price.c.max_price # 价格等于该类别最高价
)
).order_by(products_table.c.category_id)
print_sql(str(stmt_subquery_from))
result = connection.execute(stmt_subquery_from)
for row in result:
print_result_item(dict(row._mapping))
@transactional
def select_products_with_in_subquery(connection=None, min_price=100, max_price=200):
""" --- 子查询用在 IN 子句中 ---"""
print_info(f"\n子查询: 查找价格在 {min_price} 到 {max_price} 之间的产品,以及这些产品所属的类别:")
# 1.构建子查询,找出价格在指定范围内的产品ID
subquery_product_ids = select(products_table.c.id).where(
products_table.c.price.between(min_price, max_price)
).scalar_subquery()
# 2.使用IN子句查询符合条件的产品及其类别信息
stmt = select(
products_table.c.name,
products_table.c.price,
categories_table.c.name.label("category_name")
).join(
categories_table,
products_table.c.category_id == categories_table.c.category_id
).where(
products_table.c.id.in_(subquery_product_ids)
)
print_sql(str(stmt))
result = connection.execute(stmt)
for row in result:
print_result_item(dict(row._mapping))
@transactional
def select_products_with_alias_self_comparison(connection=None):
""" --- 子查询用于自连接比较 ---"""
print_info("\n子查询: 查找比同类别平均价格高的产品:")
# 1.创建产品表的别名,用于自连接
avg_price_by_category = select(
products_table.c.category_id,
func.avg(products_table.c.price).label("avg_price")
).group_by(products_table.c.category_id).subquery() # subquery 表示一个子查询,可以作为主查询的一部分 也可以作为一个用来join的对象
# 2.构建主查询,连接原表和子查询
stmt = select(
products_table.c.name,
products_table.c.category_id,
products_table.c.price,
avg_price_by_category.c.avg_price
).join(
avg_price_by_category,
products_table.c.category_id == avg_price_by_category.c.category_id,
).where(
products_table.c.price > avg_price_by_category.c.avg_price
)
print_sql(str(stmt))
result = connection.execute(stmt)
for row in result:
print_result_item(dict(row._mapping))
SQLAlchemy 通过 sqlalchemy.func
这个特殊对象,提供了一种与数据库无关的方式来调用 SQL 内置函数。这意味着你可以用统一的 Python 语法(如 func.count()
, func.now()
, func.lower()
)来生成对应数据库(如 MySQL, PostgreSQL, SQLite)的特定函数调用(如 COUNT()
, NOW()
, LOWER()
)。SQLAlchemy 的方言 (Dialect) 会负责进行正确的语法转换。
常用 func
函数调用参考表
下表列出了一些常用的 SQL 函数及其通过 func
调用的方式:
func 调用示例 |
对应 SQL 函数 (常见) | 用途说明 | 分类 |
---|---|---|---|
func.count(col) / func.count() |
COUNT(col) / COUNT(*) |
计数 (指定列或所有行) | 聚合 |
func.sum(col) |
SUM(col) |
求和 | 聚合 |
func.avg(col) |
AVG(col) |
平均值 | 聚合 |
func.max(col) |
MAX(col) |
最大值 | 聚合 |
func.min(col) |
MIN(col) |
最小值 | 聚合 |
func.count(distinct(col)) |
COUNT(DISTINCT col) |
计算非重复值的数量 | 聚合 |
func.now() |
NOW() , CURRENT_TIMESTAMP |
获取当前日期时间 (时间戳) | 日期/时间 |
func.current_date() |
CURRENT_DATE |
获取当前日期 | 日期/时间 |
func.current_time() |
CURRENT_TIME |
获取当前时间 | 日期/时间 |
func.extract(field, date_col) |
EXTRACT(field FROM date_col) |
提取日期/时间部分 (‘year’, ‘month’) | 日期/时间 |
func.lower(str_col) |
LOWER(str_col) |
字符串转小写 | 字符串 |
func.upper(str_col) |
UPPER(str_col) |
字符串转大写 | 字符串 |
func.length(str_col) |
LENGTH(str_col) , LEN() |
获取字符串长度 | 字符串 |
func.concat(*args) |
CONCAT(arg1, arg2, ...) |
字符串拼接 | 字符串 |
func.substring(str, start, len) |
SUBSTRING(str, start, len) , SUBSTR() |
提取子字符串 | 字符串 |
func.abs(num_col) |
ABS(num_col) |
绝对值 | 数学 |
func.round(num_col, digits) |
ROUND(num_col, digits) |
四舍五入 | 数学 |
func.random() |
RAND() , RANDOM() |
生成随机数 (具体行为看数据库) | 其他 |
注意: 某些函数(特别是日期/时间函数)的具体名称和行为可能因数据库方言而异,但通过 func
调用通常能提供较好的兼容性。
代码示例 (完整版)
# ======== SQLAlchemy 数据库函数示例 ========
def demonstrate_sqlalchemy_functions():
"""
展示 SQLAlchemy 中各种数据库函数的使用方法
包括聚合函数、字符串函数、日期时间函数、数学函数和条件函数
"""
print_subheader("16. 使用 SQLAlchemy func 调用数据库函数 - 完整示例")
# 演示聚合函数
#demonstrate_aggregate_functions()
# 演示字符串函数
#demonstrate_string_functions()
# 演示日期/时间函数
#demonstrate_datetime_functions()
# 演示数学函数
#demonstrate_math_functions()
# 演示条件函数
demonstrate_conditional_functions()
@transactional
def demonstrate_aggregate_functions(connection=None):
"""展示 SQLAlchemy 聚合函数的使用"""
print_info("\n聚合函数示例:产品统计")
stmt_agg_complete = select(
func.count(products_table.c.id).label("total_products"),
func.count(distinct(products_table.c.category_id)).label("distinct_categories"),
func.sum(products_table.c.stock).label("total_stock"),
func.avg(products_table.c.price).label("average_price"),
func.max(products_table.c.price).label("max_price"),
func.min(products_table.c.price).label("min_price")
)
print_sql(str(stmt_agg_complete))
try:
result = connection.execute(stmt_agg_complete).first()
if result:
print_info("聚合统计结果:")
print_result_item(dict(result._mapping))
else:
print_info("未能获取产品统计信息。")
except Exception as e:
print_error(f"执行聚合查询时出错: {e}")
@transactional
def demonstrate_datetime_functions(connection=None):
"""展示 SQLAlchemy 日期/时间函数的使用"""
print_info("\n日期/时间函数示例:数据库当前时间和日期部分提取")
stmt_func_datetime_complete = select(
func.now().label("now"),
func.current_date().label("current_date"),
func.current_time().label("current_time"),
func.current_timestamp().label("current_timestamp"),
func.sysdate().label("sysdate"),
func.localtime().label("localtime"),
func.localtimestamp().label("localtimestamp"),
func.extract('year', func.current_timestamp()).label("year"),
func.extract('month', func.current_timestamp()).label("month"),
func.extract('day', func.current_timestamp()).label("day"),
func.extract('hour', func.current_timestamp()).label("hour"),
func.extract('minute', func.current_timestamp()).label("minute"),
func.extract('second', func.current_timestamp()).label("second"),
)
print_sql(str(stmt_func_datetime_complete))
result = connection.execute(stmt_func_datetime_complete)
print_info("\n日期时间函数结果:")
print_info("-" * 80)
print_info(f"{'函数名':<20} {'值':<30}")
print_info("-" * 80)
for row in result:
for key, value in row._mapping.items():
print_info(f"{key:<20} {str(value):<30}")
print_info("-" * 80)
@transactional
def demonstrate_math_functions(connection=None):
"""展示 SQLAlchemy 数学函数的使用"""
print_info("\n数学函数示例:数学运算")
stmt_func_math = select(
func.abs(-10).label("绝对值"),
func.ceil(10.5).label("向上取整"),
func.floor(10.5).label("向下取整"),
func.sign(-10).label("符号"),
func.sqrt(16).label("平方根"),
func.power(2, 3).label("幂运算"),
func.mod(10, 3).label("取模"),
func.round(10.5).label("四舍五入"),
)
print_sql(str(stmt_func_math))
try:
result = connection.execute(stmt_func_math)
print_info("\n数学函数结果:")
print_info("-" * 80)
print_info(f"{'函数名':<20} {'值':<30}")
print_info("-" * 80)
for row in result:
for key, value in row._mapping.items():
print_info(f"{key:<20} {str(value):<30}")
print_info("-" * 80)
except Exception as e:
print_error(f"执行数学函数查询时出错: {e}")
if __name__ == '__main__':
demonstrate_sqlalchemy_functions()
ORM (Object Relational Mapper) 将数据库表映射为 Python 类,允许通过对象进行数据库操作。它是 SQLAlchemy 的核心功能,构建于 Core 之上。
核心概念: 类 (Model) <-> 表 (Table), 对象 (Instance) <-> 行 (Row), 属性 (Attribute) <-> 列 (Column)。
建议项目结构 (用于 ORM 示例):
sqlalchemy_orm_practice/ # 根目录
├── __init__.py
├── core_config.py # Engine 和 SessionLocal 设置 (使用 PyMySQL)
├── models/
│ ├── __init__.py
│ ├── base.py # DeclarativeBase, AbstractBaseModel, Mixins
│ └── orm_models.py # CategoryORM, ProductORM 模型定义
├── crud/
│ ├── __init__.py
│ └── orm_crud_ops.py # CRUD 操作示例类
├── utils/
│ ├── __init__.py
│ └── print_utils.py # 打印工具
└── main_orm_runner.py # 主运行脚本 (创建表, 调用示例)
Session
是与数据库交互的接口,管理对象状态 (Unit of Work) 和事务。
sqlalchemy_orm_practice/core_config.py
sessionmaker
(工厂) -> SessionLocal
(配置好的 Session 类) -> session = SessionLocal()
(实例)。sessionmaker
): bind=engine
(必需), autocommit=False
(默认), autoflush=False
(推荐), expire_on_commit=True
(默认)。with
): with SessionLocal() as session:
自动管理 session.close()
。方法 | 用途说明 |
---|---|
session.add(obj) |
添加新对象实例 (标记为待 INSERT)。 |
session.add_all(list) |
添加多个新对象实例。 |
session.delete(obj) |
标记持久化对象为待 DELETE。 |
session.commit() |
Flush 挂起更改并提交事务。 |
session.rollback() |
回滚事务,撤销更改。 |
session.flush() |
将更改同步到 DB (不提交事务),获取自增 ID 等。 |
session.get(Model, pk) |
通过主键高效获取单个对象。 |
session.execute(stmt) |
(2.0+) 执行 Core 语句 (ORM 查询主要方式)。 |
代码示例: Engine 和 Session 设置 (使用 PyMySQL)
"""
SQLAlchemy ORM 核心配置模块
此模块包含数据库连接引擎和会话的配置。
提供了创建数据库引擎和会话工厂的功能,用于整个应用程序的数据库交互。
主要组件:
- engine: 数据库连接引擎
- SessionLocal: 本地会话工厂,用于创建数据库会话
"""
from sqlalchemy import create_engine # 数据库引擎
from sqlalchemy.orm import sessionmaker # 会话工厂
from utils.print_utils import print_info, print_success, print_error # 打印工具
# --- 数据库连接配置 (使用 PyMySQL) ---
# !!! 生产环境应使用更安全的方式管理密码 !!!
DB_USER = "root"
DB_PASSWORD = "root"
DB_HOST = "localhost"
DB_PORT = 3306
DB_NAME = "sqlalchemy_orm_db" # 确保此数据库已创建
DB_CHARSET = "utf8mb4"
ORM_DATABASE_URL = f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}?charset={DB_CHARSET}"
print_info(f"ORM 数据库连接 URL: {ORM_DATABASE_URL.replace(DB_PASSWORD, '******')}") # 打印时隐藏密码
# --- 创建 Engine(针对 ORM) ---
orm_engine = create_engine(
ORM_DATABASE_URL,
echo=True, # 打印 SQL 语句
future=True, # 启用 SQLAlchemy 2.0 新特性
pool_size=10, # 连接池大小
max_overflow=20, # 连接池溢出时最多创建的连接数
pool_recycle=3600, # 连接池中连接的最大空闲时间,超过此时间的连接会被自动关闭
)
print_success("ORM 数据库引擎创建成功")
# --- 创建 SessionLocal 工厂 ---
SessionLocal = sessionmaker(
autocommit=False, # 自动提交
autoflush=False, # 推荐 - 关闭自动刷新
bind=orm_engine, # 绑定引擎
expire_on_commit=True # 会话在提交后过期
)
print_success("ORM 会话工厂创建成功")
# 使用示例 (通常在其他模块中):
# with SessionLocal() as session:
# # ... use session ...
# session.commit() # or session.rollback()
使用带类型注解的 Python 类映射数据库表 (SQLAlchemy 2.0 风格)。
文件: sqlalchemy_orm_practice/models/base.py
, sqlalchemy_orm_practice/models/orm_models.py
知识点:
DeclarativeBase
: 模型基类 (class Base(DeclarativeBase):
)。Base.metadata
包含模式信息。__tablename__
: 定义表名。Mapped[]
: Python 类型注解。mapped_column()
: 定义列属性 (SQL 类型, 约束)。ForeignKey("t.c")
在此定义。__abstract__ = True
: 定义不映射到表的抽象模型基类。代码示例: Base 和 Models 定义
# sqlalchemy_orm_practice/models/base.py
"""
SQLAlchemy ORM 基础模型模块
此模块定义了应用程序中所有 ORM 模型的基类和 Mixin 类。
主要组件:
- DeclarativeBase: 所有 ORM 模型的基类
- 各种可复用的 Mixin 类,如 TimestampMixin(提供时间戳功能)
"""
# -----Mapped 与 mapped_column 是做什么用的? ----------
# Mapped[] (例如 Mapped[int], Mapped[Optional[datetime]]):
# 类型注解工具:这是 SQLAlchemy 2.0 引入的,主要用于类型提示。它告诉开发者和其他工具(如静态类型检查器 Mypy)这个类属性是一个被 ORM 映射的字段,并且它在 Python 代码中的期望类型是什么。
# 提高可读性:让模型定义更清晰易懂。
# mapped_column(...) (例如 mapped_column(Integer, primary_key=True)):
# 列定义函数:这是实际用来定义数据库列属性的函数。
# 您可以在这里指定该列在数据库中的具体 SQL 类型(如 Integer, String(50), TIMESTAMP)、是否为主键 (primary_key=True)、
# 是否允许为空 (nullable=False)、默认值 (default=... 或 server_default=func.now())、索引、外键 (ForeignKey(...))、注释等。
# 它是 SQLAlchemy 2.0 中声明模型列的首选方式,用于替代旧版本中直接使用 Column(...) 的方式。
from sqlalchemy import Integer, DateTime, Boolean, TIMESTAMP, func, MetaData
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from datetime import datetime
from typing import Optional
from utils.print_utils import print_success
# --- 定义 Declarative Base ---
# DeclarativeBase 继承后会得到什么?
# 1.声明式能力:继承 DeclarativeBase 后,基类就获得了将 Python 类定义直接映射到数据库表的能力
# 2.metadata 属性:这个 Base 类会自动拥有一个名为 metadata 的 MetaData 对象。所有继承自 Base 的模型类,其表结构信息都会自动注册到这个 Base.metadata 中。
# DDL 操作基础:使得后续可以通过 Base.metadata.create_all(engine) 来自动创建所有定义的数据库表。
class Base(DeclarativeBase):
"""所有ORM 模型的基类"""
# 可选:可以定义一个 MetaData 对象来控制 ORM 行为
# metadata = MetaData()
pass
# --- 定义 Mixin (用于共享字段逻辑) ---
class TimestampMixin:
"""提供时间戳功能的 Mixin 类"""
# 使用 server_default 和 onupdate 选项,可以 利用数据库能力自动管理时间戳
create_time: Mapped[Optional[datetime]] = mapped_column(
TIMESTAMP, server_default=func.now(), comment="创建时间"
)
update_time: Mapped[Optional[datetime]] = mapped_column(
TIMESTAMP, server_default=func.now(), onupdate=func.now(), comment="更新时间"
)
# --- 定义抽象基类 (可选,用于共享 ID 等基础字段) ---
class AbstractBaseModel(Base):
"""包含 通用 ID 和 逻辑删除标记的抽象基类"""
__abstract__ = True # 不会创建对应的表
id: Mapped[int] = mapped_column(
Integer, primary_key=True, autoincrement=True, comment="主键ID"
)
# Mysql 中 常用 TINYINT(1) 来存储布尔值,sqlalchemy 中可以用 Boolean 类型 他会自动映射为 TINYINT(1)
is_deleted: Mapped[bool] = mapped_column(Boolean, server_default="0", index=True, comment="逻辑删除标记")
print_success("ORM Base, Mixin, AbstractBaseModel 已定义 (models/base.py)。")
在模型间建立关联,映射数据库外键。
文件: sqlalchemy_orm_practice/models/orm_models.py
(在模型类中添加/修改 relationship
定义)
知识点:
relationship("TargetModel", back_populates="attr", lazy="...")
back_populates
: 必须用于双向关系,值是对方模型的关系属性名。lazy
(加载策略): 'select'
(默认, 易 N+1), 'joined'
(JOIN, 适合 to-one), 'selectin'
(推荐 to-many), 'subquery'
, 'noload'
, 'raise'
。可在查询时用 options()
覆盖。secondary
: 指定多对多关联表。cascade
: 控制级联操作 (save-update
, delete
, delete-orphan
等)。# sqlalchemy_orm_practice/models/orm_models.py
"""
SQLAlchemy ORM 模型定义模块
此模块包含具体的 ORM 模型类定义,用于映射到数据库表。
主要模型:
- CategoryORM: 映射到 categories 表的分类模型
- ProductORM: 映射到 products 表的产品模型,与 CategoryORM 有关联关系
"""
from sqlalchemy import String, Integer, Numeric, ForeignKey, Boolean
from sqlalchemy.orm import Mapped, mapped_column, relationship # 导入 relationship
from typing import Optional, List
from .base import AbstractBaseModel, TimestampMixin, Base # 从同目录的 base.py 导入
from utils.print_utils import print_success # 绝对导入
# --- 定义具体模型 ---
class CategoryORM(AbstractBaseModel, TimestampMixin): # 继承 AbstractBaseModel 和 TimestampMixin
"""产品类型 ORM 模型"""
__tablename__ = "categories_orm"
name: Mapped[str] = mapped_column(
String(100), nullable=False, index=True, unique=True, comment="类别名称"
)
description: Mapped[Optional[str]] = mapped_column(
String(255), nullable=True, comment="类别描述"
)
# --- 添加关系 (见下一节) ---
products: Mapped[List["ProductORM"]] = relationship(
"ProductORM", # 目标模型类名
back_populates="category", # 指向 ProductORM 类的 category 属性
cascade="all, delete-orphan", # 级联操作: 删除 CategoryORM 时同时删除关联的 ProductORM
lazy="selectin" # 延迟加载,仅在需要时才加载数据
)
def __repr__(self):
"""打印对象信息"""
return f"{self.id}, name='{self.name}')>"
class ProductORM(AbstractBaseModel, TimestampMixin):
"""产品 ORM 模型"""
__tablename__ = "products_orm"
name: Mapped[str] = mapped_column(
String(100), index=True, comment="产品名称",
)
# 使用 Mapped[float] 作为 Python 类型提示,数据库类型由 Numeric 指定
price: Mapped[float] = mapped_column(
Numeric(10, 2), nullable=False, comment="产品价格"
)
# default 和 server_default 用于设置默认值,server_default 用于设置数据库默认值
stock: Mapped[int] = mapped_column(
Integer, default=0, server_default="0", nullable=False, comment="库存数量"
)
is_available: Mapped[bool] = mapped_column(
Boolean, default=True, server_default="1", nullable=False, comment="是否上架"
)
# --- 外键定义 ---
category_id: Mapped[Optional[int]] = mapped_column(
Integer, ForeignKey("categories_orm.id", ondelete="SET NULL"), # 关联 categories_orm 表的 id 列,删除时设置为 NULL
nullable=True, index=True, comment="类别所属ID"
)
# --- 添加关系 ---
category: Mapped[Optional["CategoryORM"]] = relationship(
"CategoryORM", back_populates="products", lazy="joined"
)
def __repr__(self):
return f"{self.id}, name='{self.name}', price={self.price})>"
print_success("具体模型 CategoryORM, ProductORM (含关系占位) 已定义 (models/orm_models.py)。")
# --- 创建表的函数 (通常在应用启动或 main 脚本中调用) ---
# def create_orm_tables(engine):
# print_info("尝试创建所有 ORM 模型对应的表...")
# try:
# Base.metadata.create_all(bind=engine)
# print_success("ORM 表已创建 (如果尚不存在)。")
# except Exception as e:
# print_error(f"创建 ORM 表时出错: {e}")
(2.0+) 主要通过 session.execute()
结合 select()
, update()
, delete()
语句,或直接操作 Session 管理的对象。
sqlalchemy_orm_practice/crud/orm_crud_ops.py
# sqlalchemy_orm_practice/crud/orm_crud_ops.py
"""
SQLAlchemy ORM CRUD 操作模块
此模块演示了 SQLAlchemy ORM 的基本 CRUD(创建、读取、更新、删除)操作。
主要功能:
- 创建: 创建新的类别和产品记录
- 读取: 按 ID、条件筛选和排序查询记录
- 更新: 更新现有记录的属性
- 删除: 从数据库中删除记录
"""
from sqlalchemy import select, update, delete, func # 导入 SQL 语句构造模块
from sqlalchemy.exc import SQLAlchemyError # 导入 SQLAlchemy 异常处理模块
from sqlalchemy.orm import Session # 导入 ORM 会话模块
from core_config import SessionLocal # 导入 ORM 会话工厂
from models.orm_models import CategoryORM, ProductORM # 导入 ORM 模型定义
from utils.print_utils import print_header, print_subheader, print_info, print_success, print_error, print_sql, \
print_warning, \
print_result_item
from typing import Optional
class OrmCrudOps:
"""封装 ORM CRUD 示例的类"""
def __init__(self, session: Session):
"""依赖注入 Session 对象"""
self.session = session
self.last_category_id: Optional[int] = None # 记录最后一个插入的类别 ID
self.last_product_id: Optional[int] = None # 记录最后一个插入的产品 ID
###### 创建对象 (INSERT) ######
def create_objects(self) -> tuple[Optional[int], Optional[int]]:
"""创建Category和Product对象并插入到数据库中"""
print_subheader("1. 创建 ORM 对象 (INSERT)")
try:
with self.session.begin_nested(): # 开启事务
# 创建 Category 对象
category = CategoryORM(name="书籍类", description="书籍类产品...包含各类图书")
self.session.add(category) # 将 Category 对象添加到会话中
self.session.flush() # 刷新会话,以便获取新插入对象的 ID
self.last_category_id = category.id # 记录最后一个插入的类别 ID
# 创建 Product 并关联
product = ProductORM(name="Python 编程指南", price=39.99, category_id=category.id)
self.session.add(product) # 将 Product 对象添加到会话中
self.session.flush() # 刷新会话,以便获取新插入对象的 ID
self.last_product_id = product.id # 记录最后一个插入的产品 ID
print_success("创建 ORM 对象成功。")
return self.last_category_id, self.last_product_id
except SQLAlchemyError as e:
print_error(f"创建 ORM 对象失败: {e}")
return None, None
##### 查询对象 (SELECT) #####
def query_object(self):
print_subheader("2. 查询 ORM 对象 (SELECT)")
try:
# 按 主键查询(session.get)
print_info(f"\n使用 session.get() 查询 Category ID={self.last_category_id}:")
category = self.session.get(CategoryORM, self.last_category_id)
if category:
print_success(f"查询结果: {category}")
else:
print_warning(f"查询结果: 未找到 ID={self.last_category_id} 的 Category 对象。")
# 条件查询(获取第一个)
print_info("\n条件查询 Product (name like '%Python%'):")
stmt_find = select(ProductORM).where(ProductORM.name.like("%Python%")).limit(1)
product = self.session.execute(stmt_find).scalars().first()
if product:
print_success(f"查询结果: {product}")
else:
print_warning("查询结果: 未找到符合条件的 Product 对象。")
except SQLAlchemyError as e:
print_error(f"查询失败: {e}")
##### 更新对象 (UPDATE) #####
def update_object(self):
print_subheader("3. 更新 ORM 对象 (UPDATE)")
try:
with self.session.begin_nested(): # 开启事务
# 方法一: 修改对象属性
print_info(f"\n修改 Product ID={self.last_product_id}的价格")
product_to_update = self.session.get(ProductORM,self.last_product_id)
if product_to_update:
product_to_update.price = 88.88 # 直接修改属性
else:print_warning(f"查询结果: 未找到 ID={self.last_product_id} 的 Product 对象。")
# 方法二: ORM 级 UPDATE语句 (批量)
print_info("\n批量降低 '书籍类' 的库存")
stmt_bulk_update = update(ProductORM).where(
ProductORM.category.has(CategoryORM.name == "书籍类")
).values(
stock = ProductORM.stock - 5 # 批量修改库存
).execution_options(synchronize_session="fetch") # 重要: 更新后刷新 session 中的对象
# synchronize_session='fetch'/'evaluate'/False
# 'fetch': 执行 UPDATE 后,SELECT 受影响行以更新 Session (最安全但可能慢)
# 'evaluate': 尝试在 Python 中评估更新效果 (快,但可能不精确)
# False: 不更新 Session 中的对象状态 (最快,但 Session 可能与 DB 不一致)
result = self.session.execute(stmt_bulk_update)
print_success(f"受影响的行数: {result.rowcount}")
print_success("更新操作已暂存 (待外层 Commit)。")
except SQLAlchemyError as e:print_error(f"更新失败: {e}")
###### 删除对象 (DELETE) ######
def delete_objects(self):
print_subheader("4. 删除 ORM 对象 (DELETE)")
try:
with self.session.begin_nested(): # 事务
# 方法一: 删除单个对象
print_info(f"\n删除 Product ID={self.last_product_id}:")
product_to_del = self.session.get(ProductORM, self.last_product_id)
if product_to_del:
self.session.delete(product_to_del) # 标记删除
print_success(f" ID={self.last_product_id} 已标记删除。")
else: print_warning(f" 未找到 ID={self.last_product_id}。")
# 方法二: ORM 级 DELETE 语句 (批量)
print_info("\n批量删除类别为 NULL 的产品:")
stmt_bulk_del = delete(ProductORM).where(ProductORM.category_id == None)
result = self.session.execute(stmt_bulk_del)
print_success(f" 批量删除影响了 {result.rowcount} 行。")
print_success(" 删除操作已暂存 (待外层 Commit)。")
except Exception as e:
print_error(f"删除对象时出错: {e}")
raise
# --- 在 main_orm_runner.py 中如何使用 ---
# with SessionLocal() as main_session:
# crud_executor = OrmCrudOps(main_session)
# try:
# crud_executor.create_objects()
# crud_executor.query_objects()
# crud_executor.update_objects()
# # crud_executor.delete_objects() # 决定是否执行删除
# main_session.commit() # 提交所有操作
# print_success("\n主事务已提交!")
# except Exception as main_err:
# print_error(f"\n主流程出错: {main_err}")
# main_session.rollback()
# print_info("主事务已回滚。")
"""
SQLAlchemy ORM 主运行脚本
此脚本作为项目的入口点,创建数据库表并运行各个示例模块的功能演示。
主要功能:
- 初始化数据库: 创建所有所需的表
- 运行示例: 按顺序调用各个模块的示例功能
- 演示工作流: 展示完整的 ORM 使用流程和最佳实践
"""
from sqlalchemy import func, inspect
from utils.print_utils import print_header, print_info, print_success, print_error, print_warning
from core_config import SessionLocal, orm_engine
from models.base import Base
from models.orm_models import CategoryORM, ProductORM
from crud.orm_crud_ops import OrmCrudOps
def create_tables():
"""创建数据库表"""
print_header("创建数据库表")
try:
# 创建所有模型对应的表
Base.metadata.create_all(bind=orm_engine)
print_success("所有 ORM 表已创建 (如果尚不存在)。")
# 打印创建的表信息
inspector = inspect(orm_engine)
tables = inspector.get_table_names()
print_info(f"数据库中的表: {tables}")
except Exception as e:
print_error(f"创建表时出错: {e}")
raise
def run_crud_examples():
"""运行 CRUD 操作示例"""
print_header("运行 ORM CRUD 操作示例")
# 使用会话上下文管理器
with SessionLocal() as session:
# 创建 CRUD 操作执行器
crud_executor = OrmCrudOps(session)
try:
# 1. 创建对象
category_id, product_id = crud_executor.create_objects()
if not category_id or not product_id:
print_warning("创建对象失败,跳过后续操作。")
return
# 2. 查询对象
crud_executor.query_object()
# 3. 更新对象
crud_executor.update_object()
# 4. 删除对象 (可选)
# 取消注释下一行以执行删除操作
# crud_executor.delete_objects()
# 提交所有操作
session.commit()
print_success("\n所有操作已成功提交!")
except Exception as e:
print_error(f"\n执行 CRUD 操作时出错: {e}")
session.rollback()
print_info("所有操作已回滚。")
if __name__ == "__main__":
# 创建表结构
# create_tables()
# 运行 CRUD 示例
run_crud_examples()
print_header("程序执行完毕")
mybatis-py
是一个为 Python 开发者设计的轻量级 SQL 映射框架,其核心理念与 Java 领域知名的 MyBatis 框架相近。它旨在提供一种方式,让开发者能够更直接地控制 SQL 语句的编写和执行,同时又通过映射机制简化 Python 代码与数据库之间的交互。对于既想利用 SQL 的全部能力进行性能调优或处理复杂逻辑,又希望避免直接操作原生数据库驱动(如 pymysql
)时繁琐的模板代码的场景,mybatis-py
是一个值得考虑的工具。
核心定位: 半自动化 ORM / SQL 映射器 (强调 SQL 控制权)。
主要功能特性概览:
特性编号 | 功能描述 | 详细说明与开发者价值 |
---|---|---|
1 | 半自动化的 ORM | 提供 Python 方法到 SQL 语句的映射,以及结果集到 Python 字典或简单对象的转换,简化数据访问代码,但开发者仍需编写 SQL。 |
2 | 支持动态 SQL | 核心特性之一。允许在 XML Mapper 文件中使用 , , , 等标签,根据传入参数动态构建和调整 SQL 语句,实现复杂查询逻辑。 |
3 | 装饰器 API | 提供了类似 MyBatis 注解的 Python 装饰器 (@mb.SelectOne , @mb.Insert 等),可直接在 Python 方法上绑定 SQL 语句,适用于简单、固定的 SQL 操作。 |
4 | LRU 缓存及过期机制 | 内置基于 LRU (Least Recently Used) 算法的查询缓存,可配置缓存池大小和条目过期时间,对不经常变动但查询频繁的数据能有效提升性能。 |
5 | Prepared Statement 支持 | 当使用 #{placeholder} 语法时,优先使用预编译语句,将 SQL 结构与数据分离,这是防御 SQL 注入攻击的关键手段。 |
6 | 预防大对象机制 (OOM) | 内置机制(通过 max_result_bytes 参数)限制从数据库拉取并处理的数据总量,旨在避免因查询结果集过大导致的应用程序内存溢出问题。 |
7 | 多数据库支持 | 目前明确支持 MySQL 和 PostgreSQL,可以通过 ConnectionFactory 指定不同的 dbms_name 。 |
**安装 mybatis-py
**
pip install -U mybatis
mybatis
: mybatis-py
库。数据库及表示例准备 (MySQL):
(确保 MySQL 服务已启动,并已创建相应数据库、用户及授权。以下 SQL 用于创建本节演示所需的库和表。)
-- 1. 创建数据库 (如果不存在),使用 utf8mb4 字符集以支持更广泛的字符
CREATE DATABASE IF NOT EXISTS mybatis_py_complete_demo CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
-- 2. 创建用户并授权 (请务必将 'your_strong_password' 替换为强密码)
CREATE USER IF NOT EXISTS 'mb_user_complete'@'localhost' IDENTIFIED BY 'your_strong_password';
GRANT ALL PRIVILEGES ON mybatis_py_complete_demo.* TO 'mb_user_complete'@'localhost';
FLUSH PRIVILEGES; -- 刷新权限使之立即生效
-- 3. 切换到目标数据库
USE mybatis_py_complete_demo;
-- 4. 创建示例表 'fruit_categories' (水果类别表)
CREATE TABLE IF NOT EXISTS fruit_categories (
id INT AUTO_INCREMENT PRIMARY KEY COMMENT '类别ID (主键)',
name VARCHAR(100) NOT NULL UNIQUE COMMENT '类别名称 (唯一)',
description TEXT COMMENT '类别描述'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='水果类别表';
-- 5. 创建示例表 'fruits' (水果信息表),并设置外键关联
CREATE TABLE IF NOT EXISTS fruits (
id INT AUTO_INCREMENT PRIMARY KEY COMMENT '水果ID (主键)',
name VARCHAR(100) NOT NULL COMMENT '水果名称',
category_id INT COMMENT '类别ID (外键), 允许为NULL表示未分类',
price INT COMMENT '价格 (单位:分,使用整数存储以避免浮点精度问题)',
description TEXT COMMENT '水果描述 (可选)',
FOREIGN KEY (category_id) REFERENCES fruit_categories(id)
ON DELETE SET NULL -- 如果关联的类别被删除,此水果的 category_id 设为 NULL
ON UPDATE CASCADE -- 如果关联的类别的 id 更新,此水果的 category_id 自动更新
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='水果信息表';
-- 6. 清空表数据并插入初始记录 (方便重复运行示例)
SET FOREIGN_KEY_CHECKS=0; -- 临时禁用外键检查,方便清空
TRUNCATE TABLE fruits;
TRUNCATE TABLE fruit_categories;
SET FOREIGN_KEY_CHECKS=1; -- 重新启用外键检查
INSERT INTO fruit_categories (name, description) VALUES
('温带水果', '如苹果、梨等'),
('热带水果', '如香蕉、芒果、菠萝等'),
('浆果类', '如草莓、蓝莓、树莓等');
INSERT INTO fruits (name, category_id, price, description) VALUES
('红富士苹果', (SELECT id FROM fruit_categories WHERE name='温带水果'), 750, '脆甜多汁的红富士苹果'),
('香芽蕉', (SELECT id FROM fruit_categories WHERE name='热带水果'), 420, '口感软糯的香芽蕉'),
('奶油草莓', (SELECT id FROM fruit_categories WHERE name='浆果类'), 1200, '大颗香甜的奶油草莓');
建议项目结构 (用于 mybatis-py
示例):
db_framework_practice/
├── examples/
│ ├── mybatis_py_v_final/ # 为当前完整重写版本创建新目录
│ │ ├── __init__.py
│ │ ├── config.py # 数据库连接参数
│ │ ├── ex01_decorators_main.py # 装饰器使用示例主程序
│ │ ├── fruit_repo_deco.py # 使用装饰器的水果仓库类
│ │ ├── mappers/ # 存放 XML Mapper 文件
│ │ │ ├── __init__.py # 使 mappers 成为一个包
│ │ │ └── fruits_mapper.xml # 使用动态sql的核心模块
│ │ └── ex02_xml_mapper_main.py # XML Mapper 使用示例主程序
│ │ └── ex03_flask_app_main.py # Flask 集成示例 (在后续提供)
├── utils/
│ ├── __init__.py
│ └── print_utils.py # 打印工具模块
└── ...
数据库配置文件: examples/mybatis_py_v_final/config.py
"""
数据库连接配置模块
此模块包含数据库连接的配置参数,用于建立与数据库的连接。
"""
# 数据库连接参数
DB_CONFIG = {
'dbms_name': 'mysql', # change to 'postgresql' if you are using PostgreSQL
'host': 'localhost',
'port': 3306,
'user': 'root',
'password': 'root',
'database': 'mybatis_demo',
'charset': 'utf8mb4'
}
# Mybatis 实例通用配置
MYBATIS_CONFIG = {
"cache_memory_limit": 10 * 1024 * 1024, # 10MB 缓存
"cache_max_live_ms": 5 * 60 * 1000, # 5 分钟缓存有效期 (毫秒)
"max_result_bytes": 50 * 1024 * 1024 # 50MB 最大结果集字节数 (用于 select_many)
}
Mybatis
类核心 API (基于提供的源码分析)
以下是对 mybatis-py
源码中 Mybatis
类的核心构造函数、主要方法及其参数的总结。
Mybatis.__init__(...)
构造函数参数
参数 | 类型 | 描述 | 源码默认/说明 |
---|---|---|---|
conn |
AbstractConnection |
必需. 已建立的数据库连接对象 (通常通过 ConnectionFactory.get_connection() 获取)。 |
- |
mapper_path |
str |
必需. XML Mapper 文件所在的目录路径或 Python 包路径 (例如 "mappers" 或 "your_package.mappers" )。 |
库会在此路径下查找并加载所有 .xml 后缀的 Mapper 文件。 |
cache_memory_limit |
Optional[int] |
可选。缓存的内存限制 (字节)。如果为 None ,源码中 Cache 对象以 0 初始化,可能表示不启用或使用不同逻辑。 |
默认: None (源码中 Cache(0,…)) |
cache_max_live_ms |
int |
可选。缓存条目的最大存活时间 (毫秒)。 | 默认: 5 * 1000 (即 5 秒) |
max_result_bytes |
int |
可选。select_many 方法返回的结果列表允许占用的最大总字节数,用于防止 OOM。 |
默认: 100 * 1024 * 1024 (即 100MB) |
XML Mapper 调用方法 (Mybatis 实例方法)
方法 | 参数 | 返回类型 | 用途说明 |
---|---|---|---|
select_one(id: str, params: dict) |
id (XML中语句的 namespace.id 或全局唯一 id ) params (传递给SQL的参数字典) |
Optional[Dict] |
执行 语句,预期返回单条记录 (字典) 或 None 。会使用缓存。 |
select_many(id: str, params: dict) |
id (XML中语句的 namespace.id 或全局唯一 id ) params (传递给SQL的参数字典) |
Optional[List[Dict]] |
执行 语句,返回多条记录 (字典列表) 或 None (若无结果)。会使用缓存及 max_result_bytes 限制。 |
update(id: str, params: dict) |
id (XML中语句的 namespace.id 或全局唯一 id ) params (传递给SQL的参数字典) |
int |
执行 语句,返回受影响的行数 (cursor.rowcount() )。会清空整个缓存。 |
delete(id: str, params: dict) |
id (XML中语句的 namespace.id 或全局唯一 id ) params (传递给SQL的参数字典) |
int |
执行 语句,返回受影响的行数 (cursor.rowcount() )。会清空整个缓存。 |
insert(id: str, params: dict, primary_key: str = None) |
id (XML中语句的 namespace.id 或全局唯一 id ) params (参数字典) primary_key (可选,用于 PostgreSQL 的 RETURNING 子句,指定主键列名) |
int |
执行 语句。返回值是 cursor.lastrowid() (通常用于获取 MySQL 自增ID)。会清空整个缓存。params 字典可能被 XML 中的 keyProperty 修改。 |
装饰器方法 (@mb.DecoratorName
)
装饰器方法 (@mb. ) |
装饰器参数 | 装饰的函数签名 (示例) | 用途及内部行为说明 |
---|---|---|---|
SelectOne(unparsed_sql) |
unparsed_sql (原始SQL字符串) |
def func(**kwargs) -> Optional[Dict] |
将函数映射到给定的 SELECT SQL,预期返回单条记录。kwargs 作为参数传递给 SQL (#{key} 占位符)。内部处理 SQL 解析、参数绑定、缓存、结果转换。 |
SelectMany(unparsed_sql) |
unparsed_sql (原始SQL字符串) |
def func(**kwargs) -> Optional[List[Dict]] |
将函数映射到给定的 SELECT SQL,返回多条记录。处理同上,并应用 max_result_bytes 。 |
Insert(unparsed_sql, primary_key=None) |
unparsed_sql (SQL), primary_key (主键列名) |
def func(**kwargs) -> int (通常返回 lastrowid ) |
将函数映射到 INSERT SQL。primary_key 主要用于 PostgreSQL 的 RETURNING 。MySQL 中通常返回 lastrowid 。清空缓存。 |
Update(unparsed_sql) |
unparsed_sql (原始SQL字符串) |
def func(**kwargs) -> int (返回受影响行数) |
将函数映射到 UPDATE SQL,返回受影响行数。清空缓存。 |
Delete(unparsed_sql) |
unparsed_sql (原始SQL字符串) |
def func(**kwargs) -> int (返回受影响行数) |
将函数映射到 DELETE SQL,返回受影响行数。清空缓存。 |
Workspace_rows(cursor, batch_size=1000)
辅助函数 (内部):
Mybatis.select_many
和装饰器 SelectMany
内部使用,用于分批从数据库游标获取数据。yield
以生成器方式逐条返回字典,这使得调用方 (如 select_many
) 可以在迭代过程中检查 max_result_bytes
限制。缓存键 (CacheKey(sql, param_list)
):
sql
) 和实际绑定的参数列表 (param_list
) 生成。方法一: 使用装饰器 API
适用于 SQL 语句相对固定且逻辑简单的场景。
db_framework_practice/examples/mybatis_py_v2/ex01_decorators_usage.py
(封装操作)"""
装饰器使用示例模块
# db_framework_practice/examples/mybatis_py_v2/ex01_decorators_usage.py
此模块演示了如何使用MyBatis-Py V2的装饰器功能进行数据库操作。
"""
from mybatis import Mybatis, ConnectionFactory
from config import DB_CONFIG
from typing import List, Dict, Optional, Any
from utils.print_utils import *
class FruitRepository:
"""水果数据仓库类,封装对水果表的所有数据库操作"""
def __init__(self, mappers_dir="mappers", cache_size=50*1024*1024):
"""
初始化水果仓库
Args:
mappers_dir: XML映射文件目录
cache_size: 缓存大小,默认50MB
"""
self.connection_factory = ConnectionFactory.get_connection(**DB_CONFIG)
self.mybatis = Mybatis(self.connection_factory, mappers_dir, cache_memory_limit=cache_size)
self._init_operations()
def _init_operations(self):
"""初始化数据库操作方法"""
@self.mybatis.SelectOne("SELECT * FROM fruits WHERE id=#{id}")
def get_one(id: int) -> Optional[Dict[str, Any]]:
"""获取一个水果记录,返回一个字典"""
pass
@self.mybatis.SelectMany("SELECT * FROM fruits WHERE category_id = #{category_id}")
def get_many(category_id: int) -> List[Dict[str, Any]]:
"""获取多个水果记录,返回一个列表"""
pass
@self.mybatis.Insert(
"INSERT INTO fruits (name, category_id, price, description) VALUES (#{name}, #{category_id}, #{price}, #{description})",
primary_key="id")
def insert(name: str, category_id: int, price: float, description: str) -> int:
"""插入水果记录,返回插入的ID"""
pass
@self.mybatis.Delete("DELETE FROM fruits WHERE id = #{id}")
def delete(id: int) -> int:
"""删除水果记录,返回删除的行数"""
pass
@self.mybatis.Update("UPDATE fruits SET name=#{name}, category_id=#{category_id}, price=#{price}, description=#{description} WHERE id=#{id}")
def update(id: int, name: str, category_id: int, price: float, description: str) -> int:
"""
更新水果记录,返回更新的行数
"""
pass
# 将所有操作绑定到实例上
self.get_one = get_one
self.get_many = get_many
self.insert = insert
self.delete = delete
self.update = update
def close(self):
"""关闭数据库连接"""
if hasattr(self, 'connection_factory') and self.connection_factory:
self.connection_factory.close()
def main():
"""主函数,用于测试FruitRepository类的功能"""
print_header("Mybatis-Py: 装饰器 API 使用示例 (面向对象版)")
# 创建仓库实例
repo = FruitRepository()
try:
# 测试查询功能
print_info("1. 测试查询单个水果")
fruit = repo.get_one(id=1)
if fruit:
print_success(f"查询到水果: {fruit['name']}, 价格: {fruit['price']}")
else:
print_warning("未找到ID为1的水果")
# 测试查询多个水果
print_info("\n2. 测试查询某分类下的所有水果")
category_id = 1
fruits = repo.get_many(category_id=category_id)
print_success(f"分类 {category_id} 下有 {len(fruits)} 个水果:")
for fruit in fruits:
print_info(f" - {fruit['name']}: ¥{fruit['price']}")
# 测试插入操作
print_info("\n3. 测试插入新水果")
new_fruit = {
"name": "蓝莓",
"category_id": 2,
"price": 25.5,
"description": "新鲜蓝莓,富含抗氧化物质"
}
new_id = repo.insert(**new_fruit)
print_success(f"插入成功,新水果ID: {new_id}")
# 测试更新操作
print_info("\n4. 测试更新水果")
update_fruit = {
"id": new_id, # 使用刚插入的ID
"name": "有机蓝莓",
"category_id": 2,
"price": 28.5,
"description": "有机认证蓝莓,无农药"
}
affected_rows = repo.update(**update_fruit)
print_success(f"更新了 {affected_rows} 行记录")
# 验证更新结果
updated_fruit = repo.get_one(id=new_id)
if updated_fruit:
print_success(f"更新后的水果: {updated_fruit['name']}, 价格: {updated_fruit['price']}")
# 测试删除操作
print_info("\n5. 测试删除水果")
deleted_rows = repo.delete(id=new_id)
print_success(f"删除了 {deleted_rows} 行记录")
except Exception as e:
print_error(f"操作失败: {str(e)}")
finally:
# 关闭连接
repo.close()
print_info("\n测试完成,已关闭数据库连接")
if __name__ == "__main__":
main()
方法二: 使用 XML Mapper 文件
对于包含动态逻辑或较为复杂的 SQL 语句,XML Mapper 文件提供了更强大的表达能力。
1. XML Mapper 文件 (ddb_framework_practice/examples/mybatis_py_v2/mappers/fruits_mapper.xml
)
DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="fruits">
<insert id="insertFruit" useGeneratedKeys="true" keyProperty="generated_id_xml" keyColumn="id">
INSERT INTO fruits (name, category_id, price, description)
VALUES (
#{name},
#{category_id},
#{price},
#{description}
)
insert>
<delete id="deleteFruitById">
DELETE FROM fruits WHERE id = #{id}
delete>
<update id="updateFruit">
UPDATE fruits
SET
name = #{name},
category_id = #{category_id},
price = #{price},
description = #{description}
WHERE id = #{id}
update>
<select id="findFruitDetailsById" resultType="dict">
SELECT
f.id,
f.name,
f.category_id,
fc.name as category_name,
f.price,
f.description
FROM fruits as f
LEFT JOIN FRUIT_CATEGORIES FC ON
f.category_id = fc.id
WHERE f.id = #{id}
select>
<select id="findFruitsByCriteria" resultType="dict">
SELECT
f.id,
f.name,
f.category_id,
fc.name as category_name,
f.price,
f.description
FROM fruits as f
LEFT JOIN fruit_categories FC ON
f.category_id = fc.id
<where>
<if test="'name' in params">
f.name LIKE CONCAT('%', #{name}, '%')
if>
where>
select>
mapper>
2. Python 代码 (db_framework_practice/examples/mybatis_py_v2/ex02_xml_mapper_usage.py
)
# db_framework_practice/examples/mybatis_py_v2/ex02_xml_mapper_usage.py
from mybatis import Mybatis, ConnectionFactory
from typing import List, Dict, Optional, Any
from utils.print_utils import *
from config import DB_CONFIG
class FruitRepoXmlOperational:
"""使用 XML Mapper 操作水果数据的数据仓库类"""
def __init__(self, mappers_dir="mappers", cache_size=0):
"""
初始化水果仓库
Args:
mappers_dir: XML映射文件目录
cache_size: 缓存大小,默认50MB
"""
self.connection_factory = ConnectionFactory.get_connection(**DB_CONFIG)
self.mybatis = Mybatis(self.connection_factory, mappers_dir, cache_memory_limit=cache_size)
def insert_fruit(self, fruit: Dict[str, Any]) -> int:
"""插入水果数据(使用xml mapper)"""
return self.mybatis.insert("fruits.insertFruit", fruit)
def delete_fruit(self, fruit_id: int) -> int:
"""删除水果数据(使用xml mapper)"""
# 将单个ID参数转换为字典格式,以符合mybatis.delete的参数要求
delete_params = {"id": fruit_id}
return self.mybatis.delete("fruits.deleteFruitById", delete_params)
def update_fruit(self, fruit: Dict[str, Any]) -> int:
"""更新水果数据(使用xml mapper)"""
result = self.mybatis.update("fruits.updateFruit", fruit)
if result > 0:
print_success(f"更新水果数据成功,影响行数: {result}")
else:
print_error(f"更新水果数据失败,影响行数: {result}")
return result
def find_fruit_details_by_id(self, fruit_id: int) -> Dict[str, Any]:
"""根据ID查询水果详细信息(使用xml mapper)"""
return self.mybatis.select_one("fruits.findFruitDetailsById", {"id": fruit_id})
def find_fruits_by_criteria(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
"""根据条件查询水果列表(使用xml mapper)
Args:
params: 查询参数字典,可包含以下键:
- category_id: 分类ID
- min_price: 最低价格
- max_price: 最高价格
- name: 水果名称(支持模糊查询)
- sort_by: 排序字段,如'price'
- sort_order: 排序方向,如'DESC'
- limit: 返回记录数量限制
- offset: 分页偏移量
Returns:
符合条件的水果列表
"""
# 直接将params传递给mybatis,不再额外包装
return self.mybatis.select_many("fruits.findFruitsByCriteria", params)
if __name__ == '__main__':
repo = FruitRepoXmlOperational()
# 测试查询水果列表
print_header("测试查询水果列表")
params = {
"name": "苹",
}
fruits = repo.find_fruits_by_criteria(params)
for fruit in fruits:
print_info(f"水果ID: {fruit['id']}, 名称: {fruit['name']}, 价格: {fruit['price']}")
将 mybatis-py
集成到 Flask Web 应用中,提供 API 接口。
项目结构:
db_framework_practice/examples/mybatis_py_v_final/ex03_flask_integration.py
db_framework_practice/examples/mybatis_py_v_final/mappers/flask_fruits_ops.xml
(可复用或创建新的)db_framework_practice/examples/mybatis_py_v_final/config.py
db_framework_practice/utils/print_utils.py
XML Mapper (db_framework_practice/examples/mybatis_py_v_final/mappers/flask_fruits_ops.xml
):
DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="flask_api_fruits"> <select id="getAllFruitsForApi" resultType="dict">
SELECT f.id, f.name, fc.name AS category_name, f.price, f.description
FROM fruits f
LEFT JOIN fruit_categories fc ON f.category_id = fc.id
ORDER BY f.name
select>
<select id="getFruitByIdForApi" resultType="dict">
SELECT f.id, f.name, fc.name AS category_name, f.price, f.description
FROM fruits f
LEFT JOIN fruit_categories fc ON f.category_id = fc.id
WHERE f.id = #{fruit_id_param}
select>
<insert id="addNewFruitForApi" useGeneratedKeys="true" keyProperty="new_fruit_id" keyColumn="id">
INSERT INTO fruits (name, category_id, price, description)
VALUES (#{name}, #{category_id}, #{price}, #{description})
insert>
<update id="updateFruitForApi">
UPDATE fruits
<set>
<if test="'name' in params and params.name != null">name = #{name},if>
<if test="'category_id' in params and params.category_id != null">category_id = #{category_id},if>
<if test="'price' in params and params.price != null">price = #{price},if>
<if test="'description' in params">description = #{description},if>
set>
WHERE id = #{id}
update>
<delete id="deleteFruitForApi">
DELETE FROM fruits WHERE id = #{id}
delete>
mapper>
Flask 应用代码 (db_framework_practice/examples/mybatis_py_v_final/ex03_flask_integration.py
):
# db_framework_practice/examples/mybatis_py_v_final/ex03_flask_integration.py
from flask import Flask, jsonify, request
import pymysql # 用于捕获 pymysql.Error
import mybatis.errors # mybatis-py 错误类
from mybatis import Mybatis, ConnectionFactory, DictCursor # 导入 DictCursor
# import orjson # 可选的高性能 JSON
import functools
import os
import sys
from typing import Optional, Callable, Any
# --- 动态路径设置,确保能导入同级和父级模块 ---
current_file_dir_flask = os.path.dirname(os.path.abspath(__file__))
project_root_flask_app = os.path.abspath(os.path.join(current_file_dir_flask, '../../../'))
if project_root_flask_app not in sys.path:
sys.path.append(project_root_flask_app)
examples_root_flask_app = os.path.abspath(os.path.join(current_file_dir_flask, '../')) # mybatis_py_v_final 的父目录
if examples_root_flask_app not in sys.path: # 允许从 mybatis_py_v_final 导入 config
sys.path.insert(0, examples_root_flask_app) # 插入到最前面,优先查找
from utils.print_utils import print_header, print_info, print_success, print_error
from mybatis_py_v_final.config import DB_CONNECTION_PARAMS_V3, MYBATIS_INSTANCE_CONFIG
app = Flask(__name__)
print_header("Mybatis-Py Flask 集成示例 (V_FINAL - 完整版)")
# --- 数据库和 Mybatis 实例管理 (简化版,生产环境需更健壮) ---
# 全局变量在多线程/多进程环境中存在风险,Flask 有更好的上下文管理机制 (如 g 对象)
_MB_INSTANCE_FLASK_APP: Optional[Mybatis] = None
_DB_CONN_ERR_FLASK_APP: bool = False
_DB_ERR_MSG_FLASK_APP: str = ""
def get_mybatis_for_request() -> Optional[Mybatis]:
"""
获取或初始化用于当前请求的 Mybatis 实例。
警告: 此简单实现仍使用全局变量,未完全适配 Flask 请求生命周期。
在真实应用中,推荐使用 Flask 的 g 对象或 before_request/teardown_request。
"""
global _MB_INSTANCE_FLASK_APP, _DB_CONN_ERR_FLASK_APP, _DB_ERR_MSG_FLASK_APP
# 检查现有连接是否有效
if _MB_INSTANCE_FLASK_APP and _MB_INSTANCE_FLASK_APP.conn:
try:
_MB_INSTANCE_FLASK_APP.conn.ping(reconnect=True) # 尝试 ping,如果断开则重连
_MB_INSTANCE_FLASK_APP.conn.set_autocommit(False) # 确保 autocommit 状态
_DB_CONN_ERR_FLASK_APP = False # 如果 ping 成功,清除错误标记
# print_info("[FLASK_APP] Existing Mybatis connection is active.")
return _MB_INSTANCE_FLASK_APP
except Exception as ping_err:
print_warning(f"[FLASK_APP] Ping failed for existing connection: {ping_err}. Attempting new connection.")
_DB_CONN_ERR_FLASK_APP = True # 标记错误
_DB_ERR_MSG_FLASK_APP = str(ping_err)
try: _MB_INSTANCE_FLASK_APP.conn.close() # 关闭无效连接
except: pass
_MB_INSTANCE_FLASK_APP = None # 清除旧实例
# 如果实例不存在,或连接出错/丢失,则创建新实例
if _MB_INSTANCE_FLASK_APP is None or _DB_CONN_ERR_FLASK_APP:
print_info("[FLASK_APP] Creating new Mybatis instance for Flask request...")
try:
db_params_with_cursor = {**DB_CONNECTION_PARAMS_V3, 'cursorclass': DictCursor}
conn_req = ConnectionFactory.get_connection(**db_params_with_cursor)
conn_req.set_autocommit(False) # 非常重要:手动控制事务
# XML Mappers 目录路径 (相对于此 ex03_flask_integration.py 文件)
mapper_dir_for_flask = os.path.join(os.path.dirname(__file__), "mappers")
if not os.path.isdir(mapper_dir_for_flask):
os.makedirs(mapper_dir_for_flask, exist_ok=True)
print_warning(f"Mappers directory '{mapper_dir_for_flask}' for Flask was created.")
_MB_INSTANCE_FLASK_APP = Mybatis(conn_req, mapper_dir_for_flask, **MYBATIS_INSTANCE_CONFIG)
_DB_CONN_ERR_FLASK_APP = False
_DB_ERR_MSG_FLASK_APP = ""
print_success("[FLASK_APP] New Mybatis instance created successfully for Flask.")
except Exception as e:
_DB_CONN_ERR_FLASK_APP = True
_DB_ERR_MSG_FLASK_APP = str(e)
_MB_INSTANCE_FLASK_APP = None
print_error(f"[FLASK_APP] Failed to create Mybatis instance for Flask: {e}")
return _MB_INSTANCE_FLASK_APP
# --- 数据库操作装饰器 ---
def db_request_handler(commit_on_success: bool = False):
"""
Flask 请求装饰器: 统一处理 Mybatis 实例获取、事务和错误。
"""
def decorator(view_function: Callable[..., Any]):
@functools.wraps(view_function)
def wrapper(*args: Any, **kwargs: Any) -> Any:
mb_instance = get_mybatis_for_request()
if not mb_instance:
return jsonify(error="Database service temporarily unavailable.", details=_DB_ERR_MSG_FLASK_APP), 503
try:
response = view_function(mb_instance, *args, **kwargs) # 注入 Mybatis 实例
if commit_on_success:
mb_instance.conn.commit()
print_success(f"[FLASK_APP] Transaction committed for: {request.path} {request.method}")
return response
except mybatis.errors.MybatisError as me:
if mb_instance.conn: mb_instance.conn.rollback()
print_error(f"[FLASK_APP] MybatisError at {request.path}: {me}")
return jsonify(error="A database operation error occurred.", details=str(me)), 500
except pymysql.Error as pe: # 更具体的数据库驱动错误
if mb_instance.conn: mb_instance.conn.rollback()
global _DB_CONN_ERR_FLASK_APP, _DB_ERR_MSG_FLASK_APP
_DB_CONN_ERR_FLASK_APP = True
_DB_ERR_MSG_FLASK_APP = str(pe)
print_error(f"[FLASK_APP] PyMySQLError at {request.path}: {pe}")
return jsonify(error="A database communication error occurred.", details=str(pe)), 502
except Exception as e_global:
if mb_instance.conn: mb_instance.conn.rollback()
print_error(f"[FLASK_APP] Unexpected error at {request.path}: {e_global}")
import traceback
traceback.print_exc() # 打印完整堆栈到服务器日志
return jsonify(error="An unexpected internal server error occurred."), 500
return wrapper
return decorator
# --- Flask 路由定义 ---
@app.route('/api/v3/fruits', methods=['GET'])
@db_request_handler() # 默认 commit_on_success=False
def api_get_all_fruits_v3(mb: Mybatis):
"""API 端点: 获取所有水果"""
print_info(f"[FLASK_APP] Handling GET /api/v3/fruits")
# 假设 XML `flask_api_fruits.getAllFruitsForApi` 已定义
fruits = mb.select_many("flask_api_fruits.getAllFruitsForApi")
print_success(f"[FLASK_APP] Found {len(fruits)} fruits.")
return jsonify(fruits if fruits else []) # 确保空列表也是有效 JSON
@app.route('/api/v3/fruits/' , methods=['GET'])
@db_request_handler()
def api_get_fruit_by_id_v3(mb: Mybatis, fruit_id: int):
"""API 端点: 根据 ID 获取单个水果"""
print_info(f"[FLASK_APP] Handling GET /api/v3/fruits/{fruit_id}")
fruit = mb.select_one("flask_api_fruits.getFruitByIdForApi", {'fruit_id_param': fruit_id})
if fruit:
print_success(f"[FLASK_APP] Found fruit: {fruit.get('name')}")
return jsonify(fruit)
else:
print_warning(f"[FLASK_APP] Fruit ID {fruit_id} not found.")
return jsonify(error=f"Fruit with ID {fruit_id} not found."), 404
@app.route('/api/v3/fruits', methods=['POST'])
@db_request_handler(commit_on_success=True) # 写操作,成功后提交
def api_add_fruit_v3(mb: Mybatis):
"""API 端点: 添加新水果"""
print_info(f"[FLASK_APP] Handling POST /api/v3/fruits")
if not request.is_json:
return jsonify(error="Request body must be JSON."), 400
data = request.get_json()
required_fields = ['name', 'price', 'category_id'] # 假设这些是必需的
if not all(field in data for field in required_fields):
return jsonify(error=f"Missing required fields in JSON: {', '.join(required_fields)}."), 400
params = {
"name": data.get("name"),
"category_id": data.get("category_id"),
"price": data.get("price"),
"description": data.get("description") # 可选
}
mb.insert('flask_api_fruits.addNewFruitForApi', params)
new_id = params.get('new_fruit_id') # 尝试获取由 keyProperty 设置的 ID
if new_id is not None:
print_success(f"[FLASK_APP] Fruit '{params['name']}' added, new ID: {new_id}.")
return jsonify(message="Fruit added successfully.", id=new_id), 201
else:
# 对于 MySQL, 如果 keyProperty 行为不符合预期, lastrowid 可能在 cursor 上
# 但 Mybatis 封装了 cursor, 这里简化处理
print_warning("[FLASK_APP] Fruit added, but ID not retrieved via keyProperty in params.")
return jsonify(message="Fruit added (ID retrieval might vary)."), 201
@app.route('/api/v3/fruits/' , methods=['PUT'])
@db_request_handler(commit_on_success=True)
def api_update_fruit_v3(mb: Mybatis, fruit_id: int):
"""API 端点: 更新指定 ID 的水果信息"""
print_info(f"[FLASK_APP] Handling PUT /api/v3/fruits/{fruit_id}")
if not request.is_json:
return jsonify(error="Request body must be JSON."), 400
data = request.get_json()
if not data: # 确保至少有一些数据用于更新
return jsonify(error="No data provided for update."), 400
params = {
"id": fruit_id, # 对应 XML 中的 WHERE id = #{id}
"name": data.get("name"), # 对应 XML 中的 name = #{name}
"category_id": data.get("category_id"),
"price": data.get("price"),
"description": data.get("description")
}
# 清理掉值为 None 的参数,以便 XML 中的 能正确工作
update_params_clean = {k: v for k, v in params.items() if v is not None}
if len(update_params_clean) <= 1 and "id" in update_params_clean: # 至少要更新一个字段
return jsonify(error="No fields to update provided besides ID."), 400
affected_rows = mb.update("flask_api_fruits.updateFruitForApi", update_params_clean)
if affected_rows > 0:
print_success(f"[FLASK_APP] Fruit ID {fruit_id} updated, {affected_rows} row(s) affected.")
return jsonify(message=f"Fruit ID {fruit_id} updated successfully.")
else:
print_warning(f"[FLASK_APP] Fruit ID {fruit_id} not found or no changes made during update.")
return jsonify(error=f"Fruit ID {fruit_id} not found or no changes applied."), 404
@app.route('/api/v3/fruits/' , methods=['DELETE'])
@db_request_handler(commit_on_success=True)
def api_delete_fruit_v3(mb: Mybatis, fruit_id: int):
"""API 端点: 删除指定 ID 的水果"""
print_info(f"[FLASK_APP] Handling DELETE /api/v3/fruits/{fruit_id}")
affected_rows = mb.delete("flask_api_fruits.deleteFruitForApi", {'id': fruit_id})
if affected_rows > 0:
print_success(f"[FLASK_APP] Fruit ID {fruit_id} deleted, {affected_rows} row(s) affected.")
return jsonify(message=f"Fruit ID {fruit_id} deleted successfully.")
else:
print_warning(f"[FLASK_APP] Fruit ID {fruit_id} not found for deletion.")
return jsonify(error=f"Fruit ID {fruit_id} not found."), 404
# --- Flask 应用关闭时的清理 ---
@app.teardown_appcontext
def shutdown_flask_db_conn_final(exception=None):
"""在应用上下文销毁时关闭数据库连接"""
global _MB_INSTANCE_FLASK_APP
if _MB_INSTANCE_FLASK_APP and _MB_INSTANCE_FLASK_APP.conn:
print_info("[FLASK_APP] App context tearing down. Closing DB connection for Flask app.")
try:
_MB_INSTANCE_FLASK_APP.conn.close()
except Exception as e_close_final:
print_error(f"[FLASK_APP] Error closing DB connection: {e_close_final}")
finally:
_MB_INSTANCE_FLASK_APP.conn = None # 确保标记为已关闭
if __name__ == '__main__':
print_info("启动 Flask 开发服务器 (Mybatis-Py V_FINAL 示例)...")
# 确保 mappers 目录存在于 ex03_flask_integration.py 文件的同级,
# 并且其中包含 flask_fruits_ops.xml 文件。
app.run(debug=True, host='0.0.0.0', port=5005) # 使用新端口避免冲突
get_mybatis_for_request()
: 此函数尝试管理 Mybatis
实例。在生产级 Flask 应用中,强烈建议使用 Flask 的 g
对象结合 before_request
和 teardown_request
(或 teardown_appcontext
) 来管理每个请求的数据库连接和 Mybatis 实例,以避免全局状态和并发问题。@db_request_handler
: 此装饰器封装了获取 Mybatis
实例、执行视图函数、事务提交/回滚以及统一的错误处理逻辑。这是一种简化视图函数代码的常用模式。commit_on_success=True
在视图函数成功执行后提交事务。读操作(GET)则不需要。Mybatis
构造函数中的 mapper_dir_path
需要正确指向包含 flask_fruits_ops.xml
的目录。示例中假设它在当前 Flask 应用文件同级的 mappers
目录下。