在Python并发编程领域,多线程技术因其轻量级和易用性广受欢迎。然而全球解释器锁(GIL)的存在使得Python多线程在CPU密集型任务中表现特殊。本文将通过理论解析、代码实战和性能测试,带你全面掌握线程同步机制,深入理解GIL的工作机制,并提供绕过性能瓶颈的解决方案。
Python通过threading
模块提供线程操作支持,以下是两种经典创建方式:
import threading
class DownloadThread(threading.Thread):
def __init__(self, url):
super().__init__()
self.url = url
def run(self):
print(f"开始下载 {self.url}")
# 模拟下载耗时
time.sleep(2)
print(f"完成下载 {self.url}")
if __name__ == "__main__":
t1 = DownloadThread("https://example.com/file1.zip")
t2 = DownloadThread("https://example.com/file2.zip")
t1.start() # 启动线程
t2.start()
t1.join() # 等待线程结束
t2.join()
def download_task(url):
print(f"开始下载 {url}")
time.sleep(2)
print(f"完成下载 {url}")
t = threading.Thread(target=download_task, args=("https://example.com/file3.zip",))
t.start()
t.join()
关键点说明:
start()
方法触发线程执行,而非直接调用run()
join()
用于阻塞主线程直至子线程完成daemon
属性控制线程是否随主线程退出全局解释器锁(Global Interpreter Lock)是CPython解释器的核心机制,表现为:
def countdown(n):
while n > 0:
n -= 1
# 单线程执行
start = time.time()
countdown(100000000)
print(f"单线程耗时: {time.time() - start:.2f}s")
# 多线程执行
t1 = threading.Thread(target=countdown, args=(50000000,))
t2 = threading.Thread(target=countdown, args=(50000000,))
start = time.time()
t1.start(); t2.start()
t1.join(); t2.join()
print(f"双线程耗时: {time.time() - start:.2f}s")
典型输出结果:
单线程耗时: 3.12s
双线程耗时: 3.25s # 多线程反而更慢!
方案 | 适用场景 | 实现方式 |
---|---|---|
多进程 | CPU密集型 | multiprocessing模块 |
C扩展 | 关键代码优化 | Cython/Numba |
异步IO | I/O密集型 | asyncio库 |
Jython实现 | 全场景 | 使用无GIL的解释器 |
from concurrent.futures import ThreadPoolExecutor
import requests
def download_page(url):
resp = requests.get(url)
return len(resp.content)
urls = ["https://www.baidu.com"] * 10
with ThreadPoolExecutor(max_workers=4) as executor:
# 提交任务
futures = [executor.submit(download_page, url) for url in urls]
# 获取结果
results = [f.result() for f in futures]
print(f"下载总字节数: {sum(results)}")
特性说明:
map()
方法简化批量任务import threading
import time
class Account:
def __init__(self):
self.balance = 0
self.lock = threading.Lock()
def deposit(self, amount):
with self.lock: # 自动获取和释放锁
new_balance = self.balance + amount
time.sleep(0.001) # 增加竞争概率
self.balance = new_balance
account = Account()
threads = []
for _ in range(100):
t = threading.Thread(target=account.deposit, args=(1,))
threads.append(t)
t.start()
for t in threads:
t.join()
print(f"最终余额: {account.balance}") # 正确应为100
import threading
import time
# 定义一个下载调度器类
class DownloadScheduler:
def __init__(self):
# 初始化一个 threading.Event 对象
# Event 对象可以用于线程间的通信,它有一个内部标志,默认为 False
self.event = threading.Event()
def prepare_data(self):
# 打印提示信息,表示开始准备数据
print("准备数据...")
# 模拟准备数据的耗时操作,暂停 2 秒
time.sleep(2)
# 调用 set 方法将 Event 对象的内部标志设置为 True
# 这会通知所有等待该事件的线程可以继续执行
self.event.set()
def start_download(self):
# 调用 wait 方法,线程会阻塞在这里,直到 Event 对象的内部标志变为 True
# 也就是等待 prepare_data 方法调用 set 方法
self.event.wait()
# 当事件被触发后,打印提示信息,表示开始下载
print("开始下载...")
# 创建 DownloadScheduler 类的一个实例
scheduler = DownloadScheduler()
# 创建一个新线程,目标函数为 scheduler.prepare_data,并启动该线程
# 这个线程负责准备数据
threading.Thread(target=scheduler.prepare_data).start()
# 创建另一个新线程,目标函数为 scheduler.start_download,并启动该线程
# 这个线程会等待数据准备好后开始下载
threading.Thread(target=scheduler.start_download).start()
生产者-消费者模型实现:
import threading
# 定义消息队列类
class MessageQueue:
def __init__(self):
# 初始化一个空列表用于存储消息
self.queue = []
# 初始化一个条件变量,用于线程间的同步
self.cond = threading.Condition()
def put(self, msg):
# 使用条件变量的上下文管理器,自动获取锁
with self.cond:
# 将消息添加到队列中
self.queue.append(msg)
# 唤醒一个等待在该条件变量上的线程
self.cond.notify()
def get(self):
# 使用条件变量的上下文管理器,自动获取锁
with self.cond:
# 当队列中没有消息时,线程进入等待状态
while not self.queue:
# 自动释放锁并等待其他线程唤醒
self.cond.wait()
# 从队列头部取出并返回消息
return self.queue.pop(0)
# 生产者函数,用于向消息队列中添加消息
def producer(queue):
for i in range(5):
# 模拟生产消息
message = f"Message {i}"
print(f"Producing {message}")
# 将消息放入队列
queue.put(message)
# 模拟生产耗时
threading.Event().wait(1)
# 消费者函数,用于从消息队列中取出消息
def consumer(queue):
for i in range(5):
# 从队列中获取消息
message = queue.get()
print(f"Consuming {message}")
if __name__ == "__main__":
# 创建消息队列实例
queue = MessageQueue()
# 创建生产者线程
producer_thread = threading.Thread(target=producer, args=(queue,))
# 创建消费者线程
consumer_thread = threading.Thread(target=consumer, args=(queue,))
# 启动生产者线程
producer_thread.start()
# 启动消费者线程
consumer_thread.start()
# 等待生产者线程执行完毕
producer_thread.join()
# 等待消费者线程执行完毕
consumer_thread.join()
print("All tasks are done.")
import time
from concurrent.futures import ThreadPoolExecutor
# 定义一个模拟 I/O 任务的函数
def test_io_task():
# 模拟数据库查询操作,让线程休眠 0.1 秒
time.sleep(0.1)
# 定义一个单线程执行任务的函数
def run_single_thread():
# 记录开始时间
start = time.time()
# 循环执行 100 次模拟 I/O 任务
for _ in range(100):
test_io_task()
# 记录结束时间,并计算耗时,保留两位小数输出
print(f"单线程耗时: {time.time() - start:.2f}s")
# 定义一个使用线程池执行任务的函数
def run_multi_thread():
# 记录开始时间
start = time.time()
# 创建一个最大线程数为 20 的线程池,并使用上下文管理器管理其生命周期
with ThreadPoolExecutor(20) as executor:
# 利用线程池中的线程并发执行 100 次模拟 I/O 任务
executor.map(test_io_task, range(100))
# 记录结束时间,并计算耗时,保留两位小数输出
print(f"20线程池耗时: {time.time() - start:.2f}s")
# 调用单线程执行任务的函数
run_single_thread() # 约10.2秒
# 调用使用线程池执行任务的函数
run_multi_thread() # 约0.6秒
from multiprocessing import Pool
import os
import time
from concurrent.futures import ThreadPoolExecutor
# CPU密集型任务函数,计算 0 到 n-1 的平方和
def cpu_bound(n):
return sum(i * i for i in range(n))
# I/O密集型任务函数,模拟 I/O 操作(睡眠 0.1 秒),并返回当前进程的 ID
def io_bound(url):
time.sleep(0.1)
return os.getpid()
if __name__ == '__main__':
# CPU密集型使用进程池
with Pool(4) as p:
# 使用进程池并行执行 cpu_bound 函数,参数为 [10**6]*4,即 4 个 10**6
# p.map 会将 cpu_bound 函数应用到列表的每个元素上,并返回结果列表
print(p.map(cpu_bound, [10**6]*4))
# I/O密集型使用线程池
with ThreadPoolExecutor(10) as executor:
# 使用线程池并行执行 io_bound 函数,参数为 ["url"]*10,即 10 个 "url"
# executor.map 会将 io_bound 函数应用到列表的每个元素上,并返回结果迭代器,转换为列表输出
print(list(executor.map(io_bound, ["url"]*10)))
lockA = threading.Lock()
lockB = threading.Lock()
def thread1():
with lockA:
time.sleep(1)
with lockB: # 此处将阻塞
print("Thread1完成")
def thread2():
with lockB:
time.sleep(1)
with lockA: # 此处将阻塞
print("Thread2完成")
# 启动两个线程观察现象
挑战: 如何修改代码避免死锁?
contextlib
中的 ExitStack
一次性获取多个锁要求:
import threading
import time
# 缓冲区大小
BUFFER_SIZE = 5
# 缓冲区
buffer = []
# 条件变量
condition = threading.Condition()
# 生产者函数
def producer(id):
global buffer
while True:
with condition:
# 当缓冲区满时,生产者等待
while len(buffer) == BUFFER_SIZE:
print(f"生产者 {id} 发现缓冲区已满,等待...")
condition.wait()
# 生产一个数据
item = f"Item-{id}"
buffer.append(item)
print(f"生产者 {id} 生产了 {item},当前缓冲区: {buffer}")
# 通知可能正在等待的消费者
condition.notify_all()
# 模拟生产耗时
time.sleep(1)
# 消费者函数
def consumer(id):
global buffer
while True:
with condition:
# 当缓冲区为空时,消费者等待
while len(buffer) == 0:
print(f"消费者 {id} 发现缓冲区为空,等待...")
condition.wait()
# 消费一个数据
item = buffer.pop(0)
print(f"消费者 {id} 消费了 {item},当前缓冲区: {buffer}")
# 通知可能正在等待的生产者
condition.notify_all()
# 模拟消费耗时
time.sleep(1)
if __name__ == "__main__":
# 创建生产者线程
producers = [threading.Thread(target=producer, args=(i,)) for i in range(2)]
# 创建消费者线程
consumers = [threading.Thread(target=consumer, args=(i,)) for i in range(2)]
# 启动生产者线程
for p in producers:
p.start()
# 启动消费者线程
for c in consumers:
c.start()
# 等待所有线程结束(这里实际上不会结束,因为是无限循环)
for p in producers:
p.join()
for c in consumers:
c.join()
设计实验对比以下场景:
import time
import threading
import multiprocessing
# 判断一个数是否为素数
def is_prime(n):
if n < 2:
return False
for i in range(2, int(n**0.5) + 1):
if n % i == 0:
return False
return True
# 单线程计算素数
def single_threaded(n):
primes = []
for i in range(n):
if is_prime(i):
primes.append(i)
return primes
# 多线程计算素数
def multi_threaded(n, num_threads):
def worker(start, end, result):
local_primes = []
for i in range(start, end):
if is_prime(i):
local_primes.append(i)
result.extend(local_primes)
chunk_size = n // num_threads
threads = []
results = [[] for _ in range(num_threads)]
for i in range(num_threads):
start = i * chunk_size
end = start + chunk_size if i < num_threads - 1 else n
t = threading.Thread(target=worker, args=(start, end, results[i]))
threads.append(t)
t.start()
for t in threads:
t.join()
primes = []
for res in results:
primes.extend(res)
return primes
# 多进程计算素数
def multi_processed(n, num_processes):
def worker(start, end, queue):
local_primes = []
for i in range(start, end):
if is_prime(i):
local_primes.append(i)
queue.put(local_primes)
chunk_size = n // num_processes
processes = []
queue = multiprocessing.Queue()
for i in range(num_processes):
start = i * chunk_size
end = start + chunk_size if i < num_processes - 1 else n
p = multiprocessing.Process(target=worker, args=(start, end, queue))
processes.append(p)
p.start()
for p in processes:
p.join()
primes = []
while not queue.empty():
primes.extend(queue.get())
return primes
if __name__ == "__main__":
n = 100000
num_threads = 4
num_processes = 4
# 单线程
start_time = time.time()
single_threaded(n)
single_time = time.time() - start_time
print(f"单线程耗时: {single_time:.4f} 秒")
# 多线程
start_time = time.time()
multi_threaded(n, num_threads)
multi_thread_time = time.time() - start_time
print(f"多线程({num_threads} 线程)耗时: {multi_thread_time:.4f} 秒")
# 多进程
start_time = time.time()
multi_processed(n, num_processes)
multi_process_time = time.time() - start_time
print(f"多进程({num_processes} 进程)耗时: {multi_process_time:.4f} 秒")
理解GIL机制是掌握Python并发的关键。对于I/O密集型任务,多线程仍然是高效选择;而CPU密集型任务应考虑多进程或混合编程。合理使用线程同步工具和线程池,结合asyncio等异步方案,才能最大化发挥Python的并发潜力。
学习路线建议:
multiprocessing
模块asyncio
异步编程模型