Welcome back, my fellow AI engineers! In Part 1 we built a basic end-to-end voice dialogue system. But just as a car fresh off the assembly line still needs careful tuning before it hits the road, our dialogue system needs further optimization before it can deliver real value in production.
Today's work is even more exciting: we are going to make the system faster, more stable, and smarter. Imagine a user talking to your AI assistant and perceiving no delay at all, as naturally as chatting with another person. That is the goal for this installment.
As we head into deeper technical waters, we will pick up streaming inference, concurrency optimization, and production deployment, the engineering skills that take a dialogue system from "it runs" to "it runs fast, runs reliably, and keeps running".
First, what is streaming inference? Traditional batch processing is like a factory assembly line: every part must be ready before assembly can start. Streaming inference is more like an efficient cook who begins cooking while ingredients are still being prepped, which dramatically cuts waiting time.
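To make the contrast concrete, here is a toy Python sketch; the `slow_tokens` generator is a stand-in for any model that produces output piece by piece, not a real inference API. The batch consumer shows nothing until everything is ready, while the streaming consumer acts on each piece as soon as it arrives.

```python
import time

def slow_tokens():
    # Stand-in for a model that emits one token every 100 ms.
    for tok in ["Hello", " ", "world", "!"]:
        time.sleep(0.1)
        yield tok

# Batch style: nothing is usable until all tokens exist (~400 ms).
batch_result = "".join(slow_tokens())
print(batch_result)

# Streaming style: the first token is usable after ~100 ms.
for tok in slow_tokens():
    print(tok, end="", flush=True)
print()
```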
Before building a high-performance streaming system, we need to weigh the strengths and weaknesses of the candidate techniques. Like picking the right tool for a job, different scenarios call for different solutions. Let's start with the ASR stage:
Approach | Real-time capability | Accuracy | Resource usage | Implementation complexity | Best fit
---|---|---|---|---|---
Whisper, streamed | Medium | 95%+ | High | Medium | Quality-critical use
Wav2Vec2 streaming | High | 90%+ | Medium | High | Real-time interaction
Hosted ASR API | High | 98%+ | Low | Low | Stable network environments
Custom lightweight model | Very high | 85%+ | Low | Very high | Edge devices
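For reference, the Whisper row can be exercised in its simplest form straight from a NumPy array using the openai-whisper package. This is a minimal one-shot chunk transcription sketch, not a full streaming loop; the silent half-second buffer is just a placeholder for real 16 kHz mono audio.

```python
import numpy as np
import whisper  # pip install openai-whisper

model = whisper.load_model("base")

# audio_chunk: mono float32 samples at 16 kHz (Whisper's expected input);
# here half a second of silence stands in for captured speech.
audio_chunk = np.zeros(8000, dtype=np.float32)

result = model.transcribe(audio_chunk, fp16=False)  # fp16=False avoids the CPU warning
print(result["text"])
```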
For the LLM stage, the main latency levers are:

Strategy | Latency improvement | Implementation difficulty | Resource cost | Consistency of gains
---|---|---|---|---
KV-cache optimization | 30-50% | Medium | +20% memory | High
Dynamic batching | 40-60% | High | +30% GPU utilization | High
Model quantization | 20-40% | Low | -50% memory | Medium
Pipeline parallelism | 50-70% | Very high | Multiple GPUs | High
Speculative decoding | 60-80% | High | +40% memory | Medium
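Of these, model quantization is the cheapest to try. The sketch below uses PyTorch's built-in dynamic quantization; the model choice (facebook/opt-125m, which uses `nn.Linear` throughout) is an example assumption, and actual savings depend on how many linear layers the model has and on your hardware.

```python
import io
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Replace nn.Linear weights with int8 versions; activations stay float.
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

def serialized_mb(m: torch.nn.Module) -> float:
    # Measure via serialized state_dict so packed int8 weights are counted too.
    buf = io.BytesIO()
    torch.save(m.state_dict(), buf)
    return buf.getbuffer().nbytes / 1e6

print(f"fp32: {serialized_mb(model):.0f} MB -> dynamic int8: {serialized_mb(quantized):.0f} MB")
```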
For the TTS stage:

Approach | First-byte latency | Audio quality | Resource usage | Best fit
---|---|---|---|---
FastSpeech2 streaming | <200ms | Excellent | Medium | Production environments
VITS streaming | <500ms | Outstanding | High | Quality-critical use
Neural vocoder | <100ms | Good | Low | Real-time applications
Hybrid synthesis | <300ms | Excellent | Medium | Balanced option
To serve many sessions at once, the concurrency models compare as follows:

Concurrency model | Throughput gain | Latency impact | Complexity | Error handling
---|---|---|---|---
Multithreading | 2-4x | Slight increase | Medium | Complex
Async IO | 5-10x | Significant reduction | High | Complex
Process pool | Linear scaling | Significant increase | Low | Simple
Coroutines | 10-50x | Lowest | Medium | Moderate
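The coroutine row is the one this series builds toward, so here is a minimal sketch of why async IO helps: three simulated sessions whose IO-bound ASR, LLM, and TTS stages overlap instead of queueing. The stage delays are placeholders, not calls into the real modules.

```python
import asyncio
import time

async def handle_session(session_id: int) -> None:
    # Placeholders for IO-bound ASR / LLM / TTS calls.
    await asyncio.sleep(0.3)  # ASR
    await asyncio.sleep(0.5)  # LLM
    await asyncio.sleep(0.2)  # TTS
    print(f"session {session_id} done")

async def main() -> None:
    start = time.perf_counter()
    # All three sessions run concurrently on one thread.
    await asyncio.gather(*(handle_session(i) for i in range(3)))
    # ~1.0 s total instead of ~3.0 s sequentially.
    print(f"elapsed: {time.perf_counter() - start:.2f}s")

asyncio.run(main())
```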
Putting it all together, these are the optimization targets for this installment:

Target | Goal | Typical current value | Headroom | Key technique
---|---|---|---|---
First-byte latency | <500ms | 1-3s | 60-80% | VAD + streaming ASR
End-to-end latency | <2s | 3-8s | 50-70% | Full-pipeline streaming
Concurrent sessions | 100+ | 10-20 | 5-10x | Async architecture
Memory footprint | <4GB | 8-16GB | 50-75% | Model quantization
GPU utilization | >80% | 30-50% | 60-160% | Batching optimization
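Before chasing these targets, measure them. A first-chunk / end-to-end timer can be as small as the sketch below; `fake_stream` is a hypothetical placeholder for whatever generator your pipeline exposes.

```python
import time

def measure_latency(pipeline_stream):
    """pipeline_stream yields output chunks; returns (first_chunk_s, total_s)."""
    start = time.perf_counter()
    first = None
    for _ in pipeline_stream:
        if first is None:
            first = time.perf_counter() - start
    return first, time.perf_counter() - start

# Example with a fake three-chunk stream:
def fake_stream():
    for _ in range(3):
        time.sleep(0.2)
        yield b"audio"

first, total = measure_latency(fake_stream())
print(f"first chunk: {first:.2f}s, end-to-end: {total:.2f}s")
```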
On the deployment side:

Deployment | Scalability | Stability | Ops complexity | Cost
---|---|---|---|---
Single machine | Low | Medium | Low | Low
Containers | Medium | High | Medium | Medium
Microservices | High | High | High | High
Serverless | Very high | Medium | Low | Elastic
Finally, matching all of the above to concrete scenarios:

Scenario | Latency target | Concurrency | Reliability | Recommended setup
---|---|---|---|---
Demo / prototype | <5s | 1-10 users | Moderate | Single machine + basic tuning
Internal enterprise tool | <3s | 50-100 users | High | Containers + streaming optimizations
Commercial product | <2s | 1000+ users | Very high | Microservices + full optimization
Real-time customer service | <1s | 10000+ users | Very high | Dedicated hardware + custom tuning
Let's start by implementing a real streaming ASR processor. This module is the system's sensitive "ear": it captures and processes speech in real time, and it uses VAD (voice activity detection) to work out when speech starts and ends.
```python
# streaming_asr.py - Streaming ASR processor
import numpy as np
import threading
import queue
import time
import logging
import asyncio
from typing import Optional, Callable, Dict, Any
from dataclasses import dataclass
from collections import deque

import webrtcvad

# Base ASR module from Part 1
from asr_module import ASREngine

logger = logging.getLogger(__name__)


@dataclass
class VADConfig:
    """VAD configuration."""
    aggressiveness: int = 3      # VAD aggressiveness (0-3)
    frame_duration: int = 30     # frame length in ms (must be 10, 20, or 30)
    sample_rate: int = 16000
    silence_threshold: int = 20  # consecutive silent frames before "speech end"
    speech_threshold: int = 5    # consecutive voiced frames before "speech start"


@dataclass
class StreamingChunk:
    """One chunk of streaming audio."""
    audio_data: np.ndarray
    timestamp: float
    is_speech: bool
    chunk_id: int
    metadata: Optional[Dict[str, Any]] = None


class StreamingVAD:
    """Streaming voice activity detection."""

    def __init__(self, config: VADConfig):
        self.config = config
        self.vad = webrtcvad.Vad(config.aggressiveness)
        # State
        self.is_speaking = False
        self.silence_frames = 0
        self.speech_frames = 0
        # Buffers
        self.audio_buffer = deque(maxlen=100)
        self.speech_buffer = []
        logger.info(f"VAD initialized, aggressiveness: {config.aggressiveness}")

    def process_frame(self, audio_frame: bytes) -> Dict[str, Any]:
        """
        Process one audio frame.

        Note: webrtcvad only accepts 16-bit mono PCM frames of exactly
        10, 20, or 30 ms, so callers must slice audio into such frames.

        Args:
            audio_frame: raw PCM frame bytes

        Returns:
            dict describing the VAD state after this frame
        """
        is_speech = self.vad.is_speech(audio_frame, self.config.sample_rate)

        result = {
            "is_speech": is_speech,
            "is_speaking": self.is_speaking,
            "state_change": False,
            "audio_data": None
        }

        if is_speech:
            self.speech_frames += 1
            self.silence_frames = 0
            # Accumulate voiced audio
            self.speech_buffer.append(audio_frame)
            # Did speech just start?
            if not self.is_speaking and self.speech_frames >= self.config.speech_threshold:
                self.is_speaking = True
                result["state_change"] = True
                result["event"] = "speech_start"
                logger.debug("Speech start detected")
        else:
            self.silence_frames += 1
            self.speech_frames = 0
            # While speaking, keep buffering (this may be a brief pause)
            if self.is_speaking:
                self.speech_buffer.append(audio_frame)
            # Did speech just end?
            if self.is_speaking and self.silence_frames >= self.config.silence_threshold:
                self.is_speaking = False
                result["state_change"] = True
                result["event"] = "speech_end"
                # Hand back the complete speech segment
                if self.speech_buffer:
                    result["audio_data"] = b"".join(self.speech_buffer)
                    self.speech_buffer.clear()
                logger.debug("Speech end detected")

        return result


class StreamingASRProcessor:
    """Streaming ASR processor."""

    def __init__(
        self,
        asr_engine: ASREngine,
        vad_config: Optional[VADConfig] = None,
        chunk_duration: float = 0.5,
        overlap_duration: float = 0.1,
        max_chunk_size: int = 16000 * 10  # maximum segment size (10 s of samples)
    ):
        """
        Initialize the streaming ASR processor.

        Args:
            asr_engine: ASR engine instance
            vad_config: VAD configuration
            chunk_duration: audio accumulated before an incremental recognition (s)
            overlap_duration: overlap kept between recognitions (s)
            max_chunk_size: maximum segment size in samples
        """
        self.asr_engine = asr_engine
        self.vad_config = vad_config or VADConfig()
        self.chunk_duration = chunk_duration
        self.overlap_duration = overlap_duration
        self.max_chunk_size = max_chunk_size

        # VAD
        self.vad = StreamingVAD(self.vad_config)

        # Processing state
        self.is_processing = False
        self.audio_queue = queue.Queue()
        self.result_queue = queue.Queue()

        # Audio buffering and context
        self.audio_buffer = deque()
        self.context_buffer = deque(maxlen=5)  # recent context

        # Worker thread
        self.processing_thread = None

        # Callbacks
        self.callbacks = {
            "on_partial_result": [],
            "on_final_result": [],
            "on_speech_start": [],
            "on_speech_end": [],
            "on_error": []
        }

        # Performance statistics
        self.stats = {
            "total_chunks": 0,
            "processing_times": [],
            "recognition_times": [],
            "queue_sizes": []
        }

        logger.info("Streaming ASR processor initialized")

    def add_callback(self, event: str, callback: Callable):
        """Register a callback."""
        if event in self.callbacks:
            self.callbacks[event].append(callback)

    def _trigger_callbacks(self, event: str, data: Any = None):
        """Invoke all callbacks registered for an event."""
        for callback in self.callbacks.get(event, []):
            try:
                callback(data)
            except Exception as e:
                logger.error(f"❌ Callback failed ({event}): {e}")

    def start_streaming(self):
        """Start streaming processing."""
        if self.is_processing:
            logger.warning("⚠️ Streaming is already running")
            return

        self.is_processing = True
        self.processing_thread = threading.Thread(
            target=self._processing_loop,
            daemon=True
        )
        self.processing_thread.start()
        logger.info("Streaming ASR started")

    def stop_streaming(self):
        """Stop streaming processing."""
        if not self.is_processing:
            return

        self.is_processing = False
        if self.processing_thread:
            self.processing_thread.join(timeout=2)
        logger.info("Streaming ASR stopped")

    def add_audio_chunk(self, audio_data: np.ndarray, timestamp: Optional[float] = None):
        """
        Enqueue an audio chunk for processing.

        Args:
            audio_data: float32 mono samples in [-1, 1]
            timestamp: capture timestamp
        """
        if not self.is_processing:
            logger.warning("⚠️ Streaming is not running")
            return

        chunk = StreamingChunk(
            audio_data=audio_data,
            timestamp=timestamp or time.time(),
            is_speech=False,
            chunk_id=self.stats["total_chunks"]
        )

        self.audio_queue.put(chunk)
        self.stats["total_chunks"] += 1
        self.stats["queue_sizes"].append(self.audio_queue.qsize())

    def _processing_loop(self):
        """Main processing loop."""
        logger.info("Streaming loop started")

        while self.is_processing:
            try:
                chunk = self.audio_queue.get(timeout=0.1)
                self._process_audio_chunk(chunk)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"❌ Processing loop error: {e}")
                self._trigger_callbacks("on_error", {"error": str(e)})

    def _process_audio_chunk(self, chunk: StreamingChunk):
        """Process a single audio chunk."""
        start_time = time.time()

        try:
            # Convert to 16-bit PCM for the VAD
            audio_int16 = (chunk.audio_data * 32767).astype(np.int16)
            audio_bytes = audio_int16.tobytes()

            # webrtcvad needs fixed-size frames, so slice the chunk into
            # frame_duration-ms frames (any trailing remainder is dropped)
            frame_bytes = 2 * int(
                self.vad_config.sample_rate * self.vad_config.frame_duration / 1000
            )
            for offset in range(0, len(audio_bytes) - frame_bytes + 1, frame_bytes):
                vad_result = self.vad.process_frame(audio_bytes[offset:offset + frame_bytes])
                chunk.is_speech = vad_result["is_speech"]

                # Handle state transitions
                if vad_result["state_change"]:
                    if vad_result["event"] == "speech_start":
                        self._trigger_callbacks("on_speech_start", chunk)
                    elif vad_result["event"] == "speech_end":
                        # Recognize the completed speech segment
                        if vad_result["audio_data"]:
                            self._process_speech_segment(vad_result["audio_data"], chunk.timestamp)
                        self._trigger_callbacks("on_speech_end", chunk)

            # While speaking, accumulate audio for incremental recognition
            if self.vad.is_speaking:
                self.audio_buffer.append(chunk.audio_data)
                buffered_samples = sum(len(a) for a in self.audio_buffer)
                if buffered_samples >= int(self.chunk_duration * self.vad_config.sample_rate):
                    self._perform_incremental_recognition()

            processing_time = time.time() - start_time
            self.stats["processing_times"].append(processing_time)

        except Exception as e:
            logger.error(f"❌ Chunk processing failed: {e}")
            self._trigger_callbacks("on_error", {"error": str(e), "chunk_id": chunk.chunk_id})

    def _process_speech_segment(self, audio_bytes: bytes, timestamp: float):
        """Recognize a complete speech segment."""
        try:
            # Convert PCM bytes back to float32
            audio_array = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32767.0

            # Run ASR (transcribe_audio_data comes from Part 1's asr_module)
            recognition_start = time.time()
            result = self.asr_engine.transcribe_audio_data(
                audio_array,
                sample_rate=self.vad_config.sample_rate
            )
            recognition_time = time.time() - recognition_start
            self.stats["recognition_times"].append(recognition_time)

            final_result = {
                "text": result.get("text", ""),
                "confidence": result.get("confidence", 0.0),
                "timestamp": timestamp,
                "recognition_time": recognition_time,
                "type": "final"
            }

            self._trigger_callbacks("on_final_result", final_result)
            logger.info(f"✅ Segment recognized: {final_result['text'][:50]}...")

        except Exception as e:
            logger.error(f"❌ Segment processing failed: {e}")
            self._trigger_callbacks("on_error", {"error": str(e)})

    def _perform_incremental_recognition(self):
        """Run incremental recognition on the buffered audio."""
        if not self.audio_buffer:
            return

        try:
            combined_audio = np.concatenate(list(self.audio_buffer))

            # Cap the segment length
            if len(combined_audio) > self.max_chunk_size:
                combined_audio = combined_audio[-self.max_chunk_size:]

            recognition_start = time.time()
            result = self.asr_engine.transcribe_audio_data(
                combined_audio,
                sample_rate=self.vad_config.sample_rate
            )
            recognition_time = time.time() - recognition_start

            partial_result = {
                "text": result.get("text", ""),
                "confidence": result.get("confidence", 0.0),
                "timestamp": time.time(),
                "recognition_time": recognition_time,
                "type": "partial"
            }

            self._trigger_callbacks("on_partial_result", partial_result)

            # Drop old audio but keep a little overlap for continuity
            overlap_samples = int(self.overlap_duration * self.vad_config.sample_rate)
            if len(combined_audio) > overlap_samples:
                self.audio_buffer.clear()
                self.audio_buffer.append(combined_audio[-overlap_samples:])

        except Exception as e:
            logger.error(f"❌ Incremental recognition failed: {e}")

    def get_statistics(self) -> Dict[str, Any]:
        """Return processing statistics."""
        stats = self.stats.copy()

        if stats["processing_times"]:
            stats["avg_processing_time"] = np.mean(stats["processing_times"])
            stats["max_processing_time"] = np.max(stats["processing_times"])

        if stats["recognition_times"]:
            stats["avg_recognition_time"] = np.mean(stats["recognition_times"])
            stats["max_recognition_time"] = np.max(stats["recognition_times"])

        if stats["queue_sizes"]:
            stats["avg_queue_size"] = np.mean(stats["queue_sizes"])
            stats["max_queue_size"] = np.max(stats["queue_sizes"])

        stats["is_processing"] = self.is_processing
        stats["vad_is_speaking"] = self.vad.is_speaking
        stats["buffer_size"] = len(self.audio_buffer)

        return stats


class AsyncStreamingASR:
    """Async wrapper around the streaming ASR processor."""

    def __init__(self, asr_processor: StreamingASRProcessor):
        self.asr_processor = asr_processor
        self.loop = None
        self.tasks = set()

    async def start_async_processing(self):
        """Start async processing."""
        self.loop = asyncio.get_running_loop()
        self.asr_processor.start_streaming()

        # Start the monitoring task
        monitor_task = self.loop.create_task(self._monitor_performance())
        self.tasks.add(monitor_task)

        logger.info("Async streaming ASR started")

    async def _monitor_performance(self):
        """Periodic performance monitoring."""
        while self.asr_processor.is_processing:
            try:
                stats = self.asr_processor.get_statistics()

                # Watch the key health indicators
                if stats.get("avg_queue_size", 0) > 10:
                    logger.warning(f"⚠️ Audio queue backlog: {stats['avg_queue_size']:.1f}")
                if stats.get("avg_processing_time", 0) > 0.1:
                    logger.warning(f"⚠️ Processing latency high: {stats['avg_processing_time']:.3f}s")

                await asyncio.sleep(5)  # check every 5 seconds
            except Exception as e:
                logger.error(f"❌ Monitoring error: {e}")
                await asyncio.sleep(1)

    async def stop_async_processing(self):
        """Stop async processing."""
        self.asr_processor.stop_streaming()

        # Cancel and drain all tasks
        for task in self.tasks:
            task.cancel()
        if self.tasks:
            await asyncio.gather(*self.tasks, return_exceptions=True)

        logger.info("Async streaming ASR stopped")


# Usage example and test code
if __name__ == "__main__":
    print("Streaming ASR processor test")
    print("=" * 50)

    # Create the ASR engine
    asr_engine = ASREngine(model_name="base")

    # Create the streaming processor
    streaming_asr = StreamingASRProcessor(asr_engine)

    # Register callbacks
    def on_partial_result(result):
        print(f"Partial result: {result['text']}")

    def on_final_result(result):
        print(f"✅ Final result: {result['text']}")
        print(f"   Confidence: {result['confidence']:.2f}")
        print(f"   Recognition time: {result['recognition_time']:.3f}s")

    def on_speech_start(chunk):
        print("Speech started")

    def on_speech_end(chunk):
        print("Speech ended")

    streaming_asr.add_callback("on_partial_result", on_partial_result)
    streaming_asr.add_callback("on_final_result", on_final_result)
    streaming_asr.add_callback("on_speech_start", on_speech_start)
    streaming_asr.add_callback("on_speech_end", on_speech_end)

    # Start streaming
    streaming_asr.start_streaming()

    # Simulate audio input (a real deployment would read from a microphone)
    print("Simulating audio input...")
    try:
        duration = 5.0  # 5 seconds of test audio
        sample_rate = 16000

        # Build a test signal with a speech-like on/off pattern:
        # 1 s silence, 3 s of "speech", 1 s silence
        t = np.linspace(0, duration, int(sample_rate * duration))
        audio_signal = np.zeros_like(t)
        speech_start = int(sample_rate * 1.0)
        speech_end = int(sample_rate * 4.0)

        # "Speech" is a simple mix of sine waves; real speech would be needed
        # for meaningful ASR output, this only exercises the pipeline
        audio_signal[speech_start:speech_end] = (
            0.3 * np.sin(2 * np.pi * 440 * t[speech_start:speech_end]) +
            0.2 * np.sin(2 * np.pi * 880 * t[speech_start:speech_end])
        )

        # Feed the audio in 100 ms chunks
        chunk_size = int(sample_rate * 0.1)
        for i in range(0, len(audio_signal), chunk_size):
            chunk = audio_signal[i:i + chunk_size]
            streaming_asr.add_audio_chunk(chunk, time.time())
            time.sleep(0.05)  # simulate real-time pacing

        # Let processing drain
        time.sleep(2)

        # Print statistics
        stats = streaming_asr.get_statistics()
        print("\nProcessing statistics:")
        print(f"  Total chunks: {stats['total_chunks']}")
        print(f"  Avg processing time: {stats.get('avg_processing_time', 0):.3f}s")
        print(f"  Avg recognition time: {stats.get('avg_recognition_time', 0):.3f}s")
        print(f"  Avg queue size: {stats.get('avg_queue_size', 0):.1f}")

    except KeyboardInterrupt:
        print("\nTest interrupted by user")
    except Exception as e:
        print(f"❌ Test failed: {e}")
    finally:
        streaming_asr.stop_streaming()
        print("Streaming ASR processor test complete!")
```
Now let's implement the streaming LLM inference engine. This is the system's "brain": it has to start thinking as soon as partial input arrives and emit its output in real time, like a sharp conversationalist who listens, thinks, and answers all at once.
```python
# streaming_llm.py - Streaming LLM inference engine
import torch
import threading
import queue
import time
import logging
from typing import Dict, List, Any, Optional, Callable, Tuple
from dataclasses import dataclass, field
import numpy as np
from collections import deque

# transformers is optional at import time
try:
    from transformers import (
        AutoTokenizer, AutoModelForCausalLM,
        GenerationConfig, TextIteratorStreamer,
        StoppingCriteria, StoppingCriteriaList
    )
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False

# Base module from Part 1
from llm_module import ChatMessage

logger = logging.getLogger(__name__)


@dataclass
class StreamingConfig:
    """Streaming inference configuration."""
    # Generation parameters
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    top_k: int = 50
    do_sample: bool = True
    repetition_penalty: float = 1.1
    # Streaming parameters
    stream_chunk_size: int = 1  # tokens emitted per step
    buffer_size: int = 10       # output buffer size
    timeout: float = 30.0       # generation timeout (s)
    # Optimizations
    use_cache: bool = True
    use_flash_attention: bool = False
    use_kv_cache_optimization: bool = True
    # Batching
    max_batch_size: int = 4
    dynamic_batching: bool = True
    batch_timeout: float = 0.1


@dataclass
class StreamingRequest:
    """A streaming inference request."""
    request_id: str
    messages: List[ChatMessage]
    config: StreamingConfig
    timestamp: float = field(default_factory=time.time)
    callback: Optional[Callable] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class StreamingResponse:
    """A streaming response: one token, or the final end-of-stream marker."""
    request_id: str
    token: str
    is_final: bool
    timestamp: float
    position: int
    total_tokens: Optional[int] = None
    latency: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)


class KVCacheManager:
    """Cross-request KV cache manager with LRU eviction."""

    def __init__(self, max_cache_size: int = 1000):
        self.max_cache_size = max_cache_size
        self.cache = {}
        self.usage_times = {}
        self.cache_hits = 0
        self.cache_misses = 0

    def get_cache_key(self, input_ids: torch.Tensor) -> str:
        """Derive a cache key from the token ids."""
        return str(hash(tuple(input_ids.flatten().tolist())))

    def get(self, cache_key: str) -> Optional[Tuple]:
        """Look up a cached past_key_values."""
        if cache_key in self.cache:
            self.usage_times[cache_key] = time.time()
            self.cache_hits += 1
            return self.cache[cache_key]
        self.cache_misses += 1
        return None

    def put(self, cache_key: str, past_key_values: Tuple):
        """Store a past_key_values, evicting the least recently used entry."""
        if len(self.cache) >= self.max_cache_size:
            oldest_key = min(self.usage_times.keys(),
                             key=lambda k: self.usage_times[k])
            del self.cache[oldest_key]
            del self.usage_times[oldest_key]

        self.cache[cache_key] = past_key_values
        self.usage_times[cache_key] = time.time()

    def get_stats(self) -> Dict[str, Any]:
        """Return cache statistics."""
        total_requests = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / max(1, total_requests)
        return {
            "cache_size": len(self.cache),
            "max_cache_size": self.max_cache_size,
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "hit_rate": hit_rate,
            "memory_usage_mb": self._estimate_memory_usage()
        }

    def _estimate_memory_usage(self) -> float:
        """Estimate cache memory in MB (past_key_values nest tuples per layer)."""
        def tensor_bytes(value) -> int:
            if isinstance(value, torch.Tensor):
                return value.numel() * value.element_size()
            if isinstance(value, (tuple, list)):
                return sum(tensor_bytes(v) for v in value)
            return 0

        total_size = sum(tensor_bytes(v) for v in self.cache.values())
        return total_size / (1024 * 1024)


class DynamicBatcher:
    """Dynamic request batcher."""

    def __init__(self, max_batch_size: int = 4, timeout: float = 0.1):
        self.max_batch_size = max_batch_size
        self.timeout = timeout
        self.pending_requests = []
        self.batch_queue = queue.Queue()
        self.stats = {
            "total_batches": 0,
            "avg_batch_size": 0,
            "batch_sizes": []
        }

    def add_request(self, request: StreamingRequest):
        """Queue a request for batching."""
        self.pending_requests.append(request)
        # Form a batch once enough requests are pending
        if len(self.pending_requests) >= self.max_batch_size:
            self._create_batch()

    def _create_batch(self):
        """Form a batch from pending requests."""
        if not self.pending_requests:
            return

        batch_size = min(len(self.pending_requests), self.max_batch_size)
        batch = self.pending_requests[:batch_size]
        self.pending_requests = self.pending_requests[batch_size:]
        self.batch_queue.put(batch)

        # Update statistics
        self.stats["total_batches"] += 1
        self.stats["batch_sizes"].append(batch_size)
        if self.stats["batch_sizes"]:
            self.stats["avg_batch_size"] = np.mean(self.stats["batch_sizes"])

    def get_batch(self, timeout: Optional[float] = None) -> Optional[List[StreamingRequest]]:
        """Fetch the next batch, flushing stragglers on timeout."""
        timeout = timeout or self.timeout
        try:
            return self.batch_queue.get(timeout=timeout)
        except queue.Empty:
            # Timed out: flush whatever is pending
            if self.pending_requests:
                self._create_batch()
                try:
                    return self.batch_queue.get_nowait()
                except queue.Empty:
                    pass
            return None


class StreamingLLMEngine:
    """Streaming LLM inference engine."""

    def __init__(
        self,
        model_name: str = "gpt2",
        device: str = "auto",
        config: Optional[StreamingConfig] = None,
        cache_dir: str = "./models/llm"
    ):
        """
        Initialize the streaming LLM engine.

        Args:
            model_name: model name
            device: device type
            config: streaming configuration
            cache_dir: model cache directory
        """
        if not TRANSFORMERS_AVAILABLE:
            raise ImportError("transformers is not available")

        self.model_name = model_name
        self.config = config or StreamingConfig()
        self.cache_dir = cache_dir

        # Pick a device
        if device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = device

        # Load the model
        self._load_model()

        # Components
        self.kv_cache = KVCacheManager()
        self.batcher = DynamicBatcher(
            max_batch_size=self.config.max_batch_size,
            timeout=self.config.batch_timeout
        )

        # Processing state
        self.is_running = False
        self.processing_thread = None

        # Request bookkeeping
        self.active_requests = {}
        self.request_counter = 0

        # Performance statistics
        self.stats = {
            "total_requests": 0,
            "total_tokens": 0,
            "avg_latency": 0,
            "latencies": deque(maxlen=1000),
            "tokens_per_second": 0,
            "token_rates": deque(maxlen=100)
        }

        logger.info(f"Streaming LLM engine initialized: {model_name} on {self.device}")

    def _load_model(self):
        """Load the model and tokenizer."""
        try:
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
                trust_remote_code=True
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
                trust_remote_code=True,
                torch_dtype=torch.float16 if self.device != "cpu" else torch.float32
            )
            self.model = self.model.to(self.device)
            self.model.eval()

            logger.info("✅ Model loaded")
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            raise

    def start_engine(self):
        """Start the inference engine."""
        if self.is_running:
            logger.warning("⚠️ Engine is already running")
            return

        self.is_running = True
        self.processing_thread = threading.Thread(
            target=self._processing_loop,
            daemon=True
        )
        self.processing_thread.start()
        logger.info("Streaming LLM engine started")

    def stop_engine(self):
        """Stop the inference engine."""
        if not self.is_running:
            return

        self.is_running = False
        if self.processing_thread:
            self.processing_thread.join(timeout=5)
        logger.info("Streaming LLM engine stopped")

    def generate_stream(
        self,
        messages: List[ChatMessage],
        callback: Optional[Callable] = None,
        **kwargs
    ) -> str:
        """
        Generate text as a stream.

        Args:
            messages: chat messages
            callback: per-token callback
            **kwargs: StreamingConfig overrides

        Returns:
            request ID
        """
        if not self.is_running:
            raise RuntimeError("Engine is not running")

        # Create the request
        request_id = f"req_{self.request_counter}_{int(time.time() * 1000)}"
        self.request_counter += 1

        # Merge per-request overrides into the engine defaults
        request_config = StreamingConfig(**{
            **self.config.__dict__,
            **kwargs
        })

        request = StreamingRequest(
            request_id=request_id,
            messages=messages,
            config=request_config,
            callback=callback
        )

        # Hand off to the batcher
        self.batcher.add_request(request)
        self.active_requests[request_id] = request

        logger.info(f"Queued generation request: {request_id}")
        return request_id

    def _processing_loop(self):
        """Main processing loop."""
        logger.info("Inference loop started")

        while self.is_running:
            try:
                batch = self.batcher.get_batch(timeout=0.1)
                if not batch:
                    continue
                self._process_batch(batch)
            except Exception as e:
                logger.error(f"❌ Processing loop error: {e}")
                time.sleep(0.1)

    def _process_batch(self, batch: List[StreamingRequest]):
        """Handle a batch of requests."""
        logger.debug(f"Processing batch of {len(batch)}")

        # For simplicity, requests are handled one at a time here; a real
        # deployment would run the forward passes as an actual batch
        for request in batch:
            try:
                self._process_single_request(request)
            except Exception as e:
                logger.error(f"❌ Request {request.request_id} failed: {e}")
                self._send_error_response(request, str(e))

    def _process_single_request(self, request: StreamingRequest):
        """Handle one request."""
        start_time = time.time()

        try:
            # Build the prompt
            input_text = self._format_messages(request.messages)

            # Tokenize
            inputs = self.tokenizer.encode(
                input_text,
                return_tensors="pt",
                truncation=True,
                max_length=4096 - request.config.max_new_tokens
            )
            if self.device != "cpu":
                inputs = inputs.to(self.device)

            # Prefix KV cache: reuse the attention cache of an identical prompt
            # prefix (everything but the last token) across requests
            past_key_values = None
            if (request.config.use_kv_cache_optimization
                    and request.config.use_cache
                    and inputs.shape[1] > 1):
                prefix = inputs[:, :-1]
                cache_key = self.kv_cache.get_cache_key(prefix)
                past_key_values = self.kv_cache.get(cache_key)
                if past_key_values is None:
                    with torch.no_grad():
                        past_key_values = self.model(prefix, use_cache=True).past_key_values
                    self.kv_cache.put(cache_key, past_key_values)

            # Generation settings
            generation_config = GenerationConfig(
                max_new_tokens=request.config.max_new_tokens,
                temperature=request.config.temperature,
                top_p=request.config.top_p,
                top_k=request.config.top_k,
                do_sample=request.config.do_sample,
                repetition_penalty=request.config.repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                use_cache=request.config.use_cache
            )

            # Stream tokens
            self._generate_streaming(request, inputs, generation_config, past_key_values)

            # Update statistics
            total_time = time.time() - start_time
            self.stats["latencies"].append(total_time)
            self.stats["total_requests"] += 1
            if self.stats["latencies"]:
                self.stats["avg_latency"] = np.mean(list(self.stats["latencies"]))

        except Exception as e:
            logger.error(f"❌ Request processing error: {e}")
            self._send_error_response(request, str(e))
        finally:
            # Drop the request from the active set
            if request.request_id in self.active_requests:
                del self.active_requests[request.request_id]

    def _generate_streaming(
        self,
        request: StreamingRequest,
        input_ids: torch.Tensor,
        generation_config: GenerationConfig,
        past_key_values: Optional[Tuple] = None
    ):
        """Run token-by-token generation."""
        position = 0
        generated_tokens = 0
        # With a prefix cache, only the last prompt token needs a forward pass
        current_input = input_ids[:, -1:] if past_key_values is not None else input_ids

        with torch.no_grad():
            while generated_tokens < request.config.max_new_tokens:
                try:
                    # Forward pass
                    outputs = self.model(
                        current_input,
                        past_key_values=past_key_values,
                        use_cache=generation_config.use_cache
                    )

                    # Next-token logits and updated cache
                    logits = outputs.logits[:, -1, :]
                    if generation_config.use_cache:
                        past_key_values = outputs.past_key_values

                    # Temperature
                    if generation_config.temperature != 1.0:
                        logits = logits / generation_config.temperature

                    # Top-k filtering
                    if generation_config.top_k > 0:
                        top_k_logits, top_k_indices = torch.topk(logits, generation_config.top_k)
                        logits = torch.full_like(logits, float('-inf'))
                        logits.scatter_(1, top_k_indices, top_k_logits)

                    # Top-p (nucleus) filtering
                    if generation_config.top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > generation_config.top_p
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        logits[indices_to_remove] = float('-inf')

                    # Sample the next token
                    if generation_config.do_sample:
                        probs = torch.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        next_token = torch.argmax(logits, dim=-1, keepdim=True)

                    # Stop at end-of-sequence
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break

                    # Decode and emit the token
                    token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
                    response = StreamingResponse(
                        request_id=request.request_id,
                        token=token_text,
                        is_final=False,
                        timestamp=time.time(),
                        position=position,
                        latency=time.time() - request.timestamp
                    )
                    self._send_response(request, response)

                    # Advance: with a KV cache only the new token is fed;
                    # without one, the whole growing sequence must be re-fed
                    if generation_config.use_cache:
                        current_input = next_token
                    else:
                        current_input = torch.cat([current_input, next_token], dim=-1)
                    position += 1
                    generated_tokens += 1
                    self.stats["total_tokens"] += 1

                    # Token throughput
                    elapsed_time = time.time() - request.timestamp
                    if elapsed_time > 0:
                        token_rate = generated_tokens / elapsed_time
                        self.stats["token_rates"].append(token_rate)
                        if self.stats["token_rates"]:
                            self.stats["tokens_per_second"] = np.mean(list(self.stats["token_rates"]))

                except Exception as e:
                    logger.error(f"❌ Generation step failed: {e}")
                    break

        # Final (end-of-stream) response
        final_response = StreamingResponse(
            request_id=request.request_id,
            token="",
            is_final=True,
            timestamp=time.time(),
            position=position,
            total_tokens=generated_tokens,
            latency=time.time() - request.timestamp
        )
        self._send_response(request, final_response)

    def _format_messages(self, messages: List[ChatMessage]) -> str:
        """Flatten chat messages into a plain prompt."""
        formatted_parts = []
        for message in messages:
            if message.role == "system":
                formatted_parts.append(f"System: {message.content}")
            elif message.role == "user":
                formatted_parts.append(f"User: {message.content}")
            elif message.role == "assistant":
                formatted_parts.append(f"Assistant: {message.content}")
        formatted_parts.append("Assistant:")
        return "\n".join(formatted_parts)

    def _send_response(self, request: StreamingRequest, response: StreamingResponse):
        """Deliver a response via the request callback."""
        if request.callback:
            try:
                request.callback(response)
            except Exception as e:
                logger.error(f"❌ Callback failed: {e}")

    def _send_error_response(self, request: StreamingRequest, error: str):
        """Deliver an error response."""
        error_response = StreamingResponse(
            request_id=request.request_id,
            token="",
            is_final=True,
            timestamp=time.time(),
            position=0,
            metadata={"error": error}
        )
        self._send_response(request, error_response)

    def get_statistics(self) -> Dict[str, Any]:
        """Return engine statistics."""
        stats = self.stats.copy()
        stats["active_requests"] = len(self.active_requests)
        stats["cache_stats"] = self.kv_cache.get_stats()
        stats["batch_stats"] = self.batcher.stats
        stats["is_running"] = self.is_running
        return stats

    def cancel_request(self, request_id: str) -> bool:
        """Cancel a pending request."""
        if request_id in self.active_requests:
            del self.active_requests[request_id]
            logger.info(f"Cancelled request: {request_id}")
            return True
        return False


# Usage example and test code
if __name__ == "__main__":
    print("Streaming LLM inference engine test")
    print("=" * 50)

    # Create the streaming engine
    streaming_llm = StreamingLLMEngine(
        model_name="gpt2",
        device="auto"
    )

    # Start the engine
    streaming_llm.start_engine()

    # Test messages
    test_messages = [
        ChatMessage(role="system", content="You are a helpful AI assistant."),
        ChatMessage(role="user", content="Please outline the history of artificial intelligence.")
    ]

    # Token callback
    generated_text = []

    def on_token_generated(response: StreamingResponse):
        if not response.is_final:
            generated_text.append(response.token)
            print(response.token, end="", flush=True)
        else:
            print(f"\n✅ Generation complete, {response.total_tokens} tokens total")
            print(f"Latency: {response.latency:.2f}s")

    try:
        # Kick off streaming generation
        print("Starting streaming generation...")
        request_id = streaming_llm.generate_stream(
            messages=test_messages,
            callback=on_token_generated,
            max_new_tokens=100,
            temperature=0.7
        )
        print(f"Request ID: {request_id}")
        print("Output: ", end="")

        # Wait for generation to finish
        time.sleep(10)

        # Engine statistics
        stats = streaming_llm.get_statistics()
        print(f"\nEngine statistics:")
        print(f"  Total requests: {stats['total_requests']}")
        print(f"  Total tokens: {stats['total_tokens']}")
        print(f"  Avg latency: {stats['avg_latency']:.2f}s")
        print(f"  Token rate: {stats['tokens_per_second']:.1f} tokens/s")
        print(f"  Active requests: {stats['active_requests']}")

        # Cache statistics
        cache_stats = stats['cache_stats']
        print(f"  Cache hit rate: {cache_stats['hit_rate']:.2%}")
        print(f"  Cache size: {cache_stats['cache_size']}")
        print(f"  Cache memory: {cache_stats['memory_usage_mb']:.1f}MB")

        # Batching statistics
        batch_stats = stats['batch_stats']
        print(f"  Total batches: {batch_stats['total_batches']}")
        print(f"  Avg batch size: {batch_stats['avg_batch_size']:.1f}")

        print(f"\nFull generated text:")
        print("".join(generated_text))

    except Exception as e:
        print(f"❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
    finally:
        streaming_llm.stop_engine()
        print("Streaming LLM inference engine test complete!")
```
Next up is the streaming TTS synthesizer, the system's "voice box". It has to start synthesizing audio the moment text begins streaming in, and it has to support streamed playback, like a professional announcer who reads and broadcasts at the same time, so the user experiences near-zero latency.
```python
# streaming_tts.py - Streaming TTS synthesizer
import threading
import queue
import time
import logging
import tempfile
import os
import re
from typing import Dict, List, Any, Optional, Callable
from dataclasses import dataclass, field
import numpy as np
from collections import deque

import pygame

# Base TTS module from Part 1
from tts_module import TTSManager

logger = logging.getLogger(__name__)


@dataclass
class StreamingTTSConfig:
    """Streaming TTS configuration."""
    # Text handling
    sentence_splitters: List[str] = field(default_factory=lambda: ['.', '!', '?', '。', '!', '?'])
    min_chunk_length: int = 10   # minimum text length worth synthesizing
    max_chunk_length: int = 200  # maximum text length per synthesis call
    # Audio handling
    sample_rate: int = 22050
    buffer_duration: float = 2.0     # audio buffer length (s)
    crossfade_duration: float = 0.1  # crossfade length (s)
    # Performance
    max_concurrent_synthesis: int = 3  # synthesis worker threads
    synthesis_timeout: float = 10.0    # synthesis timeout (s)
    # Quality
    enable_ssml: bool = False  # whether to emit SSML
    voice_speed: float = 1.0
    voice_pitch: float = 1.0


@dataclass
class TextChunk:
    """A chunk of text awaiting synthesis."""
    chunk_id: int
    text: str
    timestamp: float
    is_final: bool = False
    priority: int = 0  # lower value = higher priority
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class AudioChunk:
    """A chunk of synthesized audio."""
    chunk_id: int
    audio_data: np.ndarray
    sample_rate: int
    timestamp: float
    duration: float
    is_final: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)


class TextProcessor:
    """Splits incoming text into synthesizable chunks."""

    def __init__(self, config: StreamingTTSConfig):
        self.config = config
        self.chunk_buffer = ""
        self.chunk_counter = 0
        # Pre-compile the split patterns
        self.sentence_pattern = self._compile_sentence_pattern()
        self.punctuation_pattern = re.compile(r'[,.;:,、;:]')

    def _compile_sentence_pattern(self) -> re.Pattern:
        """Compile the sentence-boundary pattern."""
        splitters = ''.join(re.escape(s) for s in self.config.sentence_splitters)
        return re.compile(f'[{splitters}]')

    def add_text(self, text: str, is_final: bool = False) -> List[TextChunk]:
        """
        Append text and return any chunks that are ready to synthesize.

        Args:
            text: incoming text
            is_final: whether this is the last piece of text

        Returns:
            list of ready text chunks
        """
        chunks = []
        self.chunk_buffer += text
        search_start = 0

        while True:
            # Look for a sentence boundary
            match = self.sentence_pattern.search(self.chunk_buffer, search_start)
            if match:
                sentence_end = match.end()
                sentence = self.chunk_buffer[:sentence_end].strip()

                if len(sentence) >= self.config.min_chunk_length:
                    chunks.append(TextChunk(
                        chunk_id=self.chunk_counter,
                        text=sentence,
                        timestamp=time.time(),
                        is_final=False
                    ))
                    self.chunk_counter += 1
                    self.chunk_buffer = self.chunk_buffer[sentence_end:].lstrip()
                    search_start = 0
                else:
                    # Sentence too short: keep it and scan for the next boundary
                    search_start = sentence_end
            else:
                if is_final and self.chunk_buffer.strip():
                    # Final flush: synthesize whatever remains
                    chunks.append(TextChunk(
                        chunk_id=self.chunk_counter,
                        text=self.chunk_buffer.strip(),
                        timestamp=time.time(),
                        is_final=True
                    ))
                    self.chunk_counter += 1
                    self.chunk_buffer = ""
                    search_start = 0
                elif len(self.chunk_buffer) > self.config.max_chunk_length:
                    # Buffer too long: force a split, preferably at a pause mark
                    punct_match = None
                    for m in self.punctuation_pattern.finditer(self.chunk_buffer):
                        if m.end() >= self.config.min_chunk_length:
                            punct_match = m
                            break

                    split_pos = punct_match.end() if punct_match else self.config.max_chunk_length
                    chunk_text = self.chunk_buffer[:split_pos].strip()
                    if chunk_text:
                        chunks.append(TextChunk(
                            chunk_id=self.chunk_counter,
                            text=chunk_text,
                            timestamp=time.time(),
                            is_final=False
                        ))
                        self.chunk_counter += 1
                    self.chunk_buffer = self.chunk_buffer[split_pos:].lstrip()
                    search_start = 0
                else:
                    break

        return chunks

    def get_buffered_text(self) -> str:
        """Return the text still waiting in the buffer."""
        return self.chunk_buffer

    def clear_buffer(self):
        """Clear the text buffer."""
        self.chunk_buffer = ""


class AudioBuffer:
    """Playback buffer for synthesized audio."""

    def __init__(self, config: StreamingTTSConfig):
        self.config = config
        self.buffer = deque()
        self.buffer_duration = 0.0
        self.max_buffer_duration = config.buffer_duration
        self.is_playing = False
        self.lock = threading.Lock()

        # Initialize the pygame mixer (mono, 16-bit, at the configured rate)
        pygame.mixer.pre_init(
            frequency=config.sample_rate,
            size=-16,
            channels=1,
            buffer=1024
        )
        pygame.mixer.init()

    def add_audio(self, audio_chunk: AudioChunk):
        """Append an audio chunk to the buffer."""
        with self.lock:
            self.buffer.append(audio_chunk)
            self.buffer_duration += audio_chunk.duration

            # Cap the buffer; note this drops the oldest audio on overflow
            while self.buffer_duration > self.max_buffer_duration:
                if self.buffer:
                    removed_chunk = self.buffer.popleft()
                    self.buffer_duration -= removed_chunk.duration
                else:
                    break

    def get_next_audio(self) -> Optional[AudioChunk]:
        """Pop the next audio chunk, if any."""
        with self.lock:
            if self.buffer:
                chunk = self.buffer.popleft()
                self.buffer_duration -= chunk.duration
                return chunk
            return None

    def get_buffer_status(self) -> Dict[str, Any]:
        """Return buffer status."""
        with self.lock:
            return {
                "buffer_size": len(self.buffer),
                "buffer_duration": self.buffer_duration,
                "max_buffer_duration": self.max_buffer_duration,
                "is_playing": self.is_playing
            }

    def clear_buffer(self):
        """Clear the audio buffer."""
        with self.lock:
            self.buffer.clear()
            self.buffer_duration = 0.0


class StreamingTTSSynthesizer:
    """Streaming TTS synthesizer."""

    def __init__(
        self,
        tts_manager: TTSManager,
        config: Optional[StreamingTTSConfig] = None
    ):
        """
        Initialize the streaming TTS synthesizer.

        Args:
            tts_manager: TTS manager instance
            config: streaming TTS configuration
        """
        self.tts_manager = tts_manager
        self.config = config or StreamingTTSConfig()

        # Components
        self.text_processor = TextProcessor(self.config)
        self.audio_buffer = AudioBuffer(self.config)

        # Processing state
        self.is_running = False
        self.synthesis_threads = []
        self.playback_thread = None

        # Queues
        self.text_queue = queue.PriorityQueue()
        self.synthesis_queue = queue.Queue()

        # Statistics
        self.stats = {
            "total_text_chunks": 0,
            "total_audio_chunks": 0,
            "avg_synthesis_time": 0,
            "synthesis_times": deque(maxlen=100),
            "first_audio_latency": None,
            "total_synthesis_latency": 0
        }

        # Callbacks
        self.callbacks = {
            "on_audio_ready": [],
            "on_synthesis_complete": [],
            "on_playback_start": [],
            "on_playback_complete": [],
            "on_error": []
        }

        logger.info("Streaming TTS synthesizer initialized")

    def add_callback(self, event: str, callback: Callable):
        """Register a callback."""
        if event in self.callbacks:
            self.callbacks[event].append(callback)

    def _trigger_callbacks(self, event: str, data: Any = None):
        """Invoke all callbacks registered for an event."""
        for callback in self.callbacks.get(event, []):
            try:
                callback(data)
            except Exception as e:
                logger.error(f"❌ Callback failed ({event}): {e}")

    def start_synthesis(self):
        """Start the synthesizer."""
        if self.is_running:
            logger.warning("⚠️ Synthesizer is already running")
            return

        self.is_running = True

        # Synthesis worker pool
        for i in range(self.config.max_concurrent_synthesis):
            thread = threading.Thread(
                target=self._synthesis_worker,
                daemon=True,
                name=f"TTS-Synthesis-{i}"
            )
            thread.start()
            self.synthesis_threads.append(thread)

        # Playback thread
        self.playback_thread = threading.Thread(
            target=self._playback_worker,
            daemon=True,
            name="TTS-Playback"
        )
        self.playback_thread.start()

        logger.info("Streaming TTS synthesizer started")

    def stop_synthesis(self):
        """Stop the synthesizer."""
        if not self.is_running:
            return

        self.is_running = False

        # Wait for workers to exit
        for thread in self.synthesis_threads:
            thread.join(timeout=2)
        if self.playback_thread:
            self.playback_thread.join(timeout=2)

        # Release resources
        self.audio_buffer.clear_buffer()
        pygame.mixer.quit()

        logger.info("Streaming TTS synthesizer stopped")

    def add_text(self, text: str, is_final: bool = False, priority: int = 0):
        """
        Submit text for synthesis.

        Args:
            text: incoming text
            is_final: whether this is the last piece of text
            priority: priority (lower = sooner)
        """
        if not self.is_running:
            logger.warning("⚠️ Synthesizer is not running")
            return

        # Split into synthesizable chunks
        text_chunks = self.text_processor.add_text(text, is_final)

        # Enqueue (chunk_id breaks priority ties to keep ordering stable)
        for chunk in text_chunks:
            chunk.priority = priority
            self.text_queue.put((priority, chunk.chunk_id, chunk))
            self.stats["total_text_chunks"] += 1

        logger.debug(f"Queued {len(text_chunks)} text chunk(s)")

    def _synthesis_worker(self):
        """Synthesis worker thread."""
        logger.debug(f"Synthesis worker started: {threading.current_thread().name}")

        while self.is_running:
            try:
                priority, chunk_id, text_chunk = self.text_queue.get(timeout=0.1)
                self._synthesize_chunk(text_chunk)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"❌ Synthesis worker error: {e}")
                self._trigger_callbacks("on_error", {"error": str(e)})

    def _synthesize_chunk(self, text_chunk: TextChunk):
        """Synthesize one text chunk."""
        start_time = time.time()

        try:
            logger.debug(f"Synthesizing: {text_chunk.text[:30]}...")

            # Run TTS (synthesize() comes from Part 1's tts_module and
            # returns encoded audio bytes)
            audio_data = self.tts_manager.synthesize(text_chunk.text)

            if audio_data:
                audio_array = self._convert_audio_data(audio_data)

                audio_chunk = AudioChunk(
                    chunk_id=text_chunk.chunk_id,
                    audio_data=audio_array,
                    sample_rate=self.config.sample_rate,
                    timestamp=time.time(),
                    duration=len(audio_array) / self.config.sample_rate,
                    is_final=text_chunk.is_final
                )

                # Queue for playback
                self.audio_buffer.add_audio(audio_chunk)

                # Update statistics
                synthesis_time = time.time() - start_time
                self.stats["synthesis_times"].append(synthesis_time)
                self.stats["total_audio_chunks"] += 1
                if self.stats["synthesis_times"]:
                    self.stats["avg_synthesis_time"] = np.mean(list(self.stats["synthesis_times"]))

                # Record time-to-first-audio
                if self.stats["first_audio_latency"] is None:
                    self.stats["first_audio_latency"] = synthesis_time

                self._trigger_callbacks("on_audio_ready", audio_chunk)
                logger.debug(f"✅ Synthesized chunk {text_chunk.chunk_id} in {synthesis_time:.3f}s")

        except Exception as e:
            logger.error(f"❌ Chunk synthesis failed: {e}")
            self._trigger_callbacks("on_error", {
                "error": str(e),
                "chunk_id": text_chunk.chunk_id,
                "text": text_chunk.text
            })

    def _convert_audio_data(self, audio_data: bytes) -> np.ndarray:
        """Decode TTS output bytes into a float32 mono array."""
        try:
            # Write to a temp file so pygame can decode it. This assumes the
            # TTS engine returns a format pygame can load (WAV/OGG are safe;
            # MP3 support depends on the platform's SDL_mixer build).
            with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as temp_file:
                temp_file.write(audio_data)
                temp_path = temp_file.name

            # Decode with pygame
            sound = pygame.mixer.Sound(temp_path)
            raw_data = pygame.sndarray.array(sound)

            # Convert to float32 in [-1, 1]
            if raw_data.dtype == np.int16:
                audio_array = raw_data.astype(np.float32) / 32767.0
            else:
                audio_array = raw_data.astype(np.float32)

            # Downmix to mono if needed
            if len(audio_array.shape) > 1:
                audio_array = np.mean(audio_array, axis=1)

            # Remove the temp file
            os.unlink(temp_path)
            return audio_array

        except Exception as e:
            logger.error(f"❌ Audio conversion failed: {e}")
            return np.array([])

    def _playback_worker(self):
        """Playback worker thread."""
        logger.debug("Playback worker started")
        self._trigger_callbacks("on_playback_start", None)

        while self.is_running:
            try:
                audio_chunk = self.audio_buffer.get_next_audio()
                if audio_chunk:
                    self._play_audio_chunk(audio_chunk)
                    if audio_chunk.is_final:
                        self._trigger_callbacks("on_synthesis_complete", None)
                else:
                    time.sleep(0.01)  # idle-wait when no audio is buffered
            except Exception as e:
                logger.error(f"❌ Playback worker error: {e}")
                time.sleep(0.1)

        self._trigger_callbacks("on_playback_complete", None)

    def _play_audio_chunk(self, audio_chunk: AudioChunk):
        """Play one audio chunk."""
        try:
            # Back to int16 for pygame
            audio_int16 = (audio_chunk.audio_data * 32767).astype(np.int16)

            # Play and block until finished
            sound = pygame.sndarray.make_sound(audio_int16)
            sound.play()
            while pygame.mixer.get_busy():
                time.sleep(0.01)

            logger.debug(f"Played chunk {audio_chunk.chunk_id}")
        except Exception as e:
            logger.error(f"❌ Audio playback failed: {e}")

    def get_synthesis_status(self) -> Dict[str, Any]:
        """Return synthesizer status."""
        return {
            "is_running": self.is_running,
            "text_queue_size": self.text_queue.qsize(),
            "buffered_text": self.text_processor.get_buffered_text(),
            "audio_buffer": self.audio_buffer.get_buffer_status(),
            "stats": self.stats.copy()
        }

    def clear_queues(self):
        """Flush all queues and buffers (e.g. when the user interrupts)."""
        # Drain the text queue
        while not self.text_queue.empty():
            try:
                self.text_queue.get_nowait()
            except queue.Empty:
                break

        # Clear the audio buffer
        self.audio_buffer.clear_buffer()

        # Clear the text processor buffer
        self.text_processor.clear_buffer()

        logger.info("All queues cleared")


# Usage example and test code
if __name__ == "__main__":
    print("Streaming TTS synthesizer test")
    print("=" * 50)

    # Create the TTS manager
    tts_manager = TTSManager()

    # Create the streaming synthesizer
    streaming_tts = StreamingTTSSynthesizer(tts_manager)

    # Register callbacks (each receives one argument from _trigger_callbacks)
    def on_audio_ready(audio_chunk):
        print(f"Audio ready: chunk {audio_chunk.chunk_id}, duration: {audio_chunk.duration:.2f}s")

    def on_synthesis_complete(_):
        print("✅ Synthesis complete")

    def on_playback_start(_):
        print("Playback started")

    def on_playback_complete(_):
        print("Playback complete")

    def on_error(error_data):
        print(f"❌ Error: {error_data}")

    streaming_tts.add_callback("on_audio_ready", on_audio_ready)
    streaming_tts.add_callback("on_synthesis_complete", on_synthesis_complete)
    streaming_tts.add_callback("on_playback_start", on_playback_start)
    streaming_tts.add_callback("on_playback_complete", on_playback_complete)
    streaming_tts.add_callback("on_error", on_error)

    try:
        # Start the synthesizer
        streaming_tts.start_synthesis()

        # Test texts
        test_texts = [
            "Hello, I am your AI voice assistant.",
            "The weather is lovely today, perfect for a walk.",
            "Artificial intelligence is advancing rapidly,",
            "bringing convenience to many parts of our lives.",
            "Thanks for listening. Goodbye!"
        ]

        # Simulate streamed text input
        print("Starting streamed text input...")
        for i, text in enumerate(test_texts):
            print(f"Input text {i + 1}: {text}")
            is_final = (i == len(test_texts) - 1)
            streaming_tts.add_text(text, is_final=is_final)
            time.sleep(1)  # simulate streaming pace

        # Wait for processing to finish
        print("⏳ Waiting for processing to finish...")
        time.sleep(10)

        # Report status
        status = streaming_tts.get_synthesis_status()
        print(f"\nSynthesis status:")
        print(f"  Text queue size: {status['text_queue_size']}")
        print(f"  Buffered text: '{status['buffered_text']}'")
        print(f"  Audio buffer: {status['audio_buffer']}")

        stats = status['stats']
        print(f"  Total text chunks: {stats['total_text_chunks']}")
        print(f"  Total audio chunks: {stats['total_audio_chunks']}")
        print(f"  Avg synthesis time: {stats['avg_synthesis_time']:.3f}s")
        if stats['first_audio_latency'] is not None:
            print(f"  First-audio latency: {stats['first_audio_latency']:.3f}s")

    except Exception as e:
        print(f"❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
    finally:
        streaming_tts.stop_synthesis()
        print("Streaming TTS synthesizer test complete!")
```
So, how did you like today's content? Thanks again for reading, friends. Follow the official account (GZH) 凡人的AI工具箱 and reply 666 to receive an AI goodie pack worth 199. Finally, here's to your financial freedom, and please leave a like. Thank you!