PyTorch深度学习框架60天进阶学习计划 - 第58天端到端对话系统(二):流式优化与生产部署

欢迎回来,我的AI工程师朋友们! 经过第一部分的学习,我们已经搭建了一个基础的端到端对话系统。但是,就像一辆刚下生产线的汽车还需要精心调校才能上路一样,我们的对话系统也需要进一步优化才能在真实场景中发挥最大价值!

今天我们要做的事情更加激动人心——我们要让系统变得更快、更稳定、更智能!想象一下,当用户和你的AI助手对话时,感受不到任何延迟,就像和真人聊天一样自然流畅。这就是我们今天要实现的目标!

在这个技术深水区的探索中,我们将掌握流式推理、并发优化、生产部署等企业级技能,让你的对话系统从"能跑"升级到"跑得快、跑得稳、跑得久"!


第二部分:流式推理优化与生产级部署

1. 流式推理架构设计

首先,让我们来理解什么是流式推理。传统的批处理模式就像工厂流水线,必须等所有零件准备好才能开始组装;而流式推理则像一个高效的厨师,一边准备食材,一边开始烹饪,大大减少了等待时间!具体来说,流式推理在拿到第一段输入后就开始计算,并把部分结果立刻传给下游模块,让ASR、LLM、TTS各环节的耗时相互重叠,而不是串行累加。
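
下面用一段极简的示意代码直观感受两种模式的差别。注意这只是演示流水线思想的草图,fake_llm 是假设的占位函数,用逐字产出模拟逐token推理,并非真实引擎接口:

# streaming_concept.py - 批处理与流式处理的概念对比(示意代码)
import time
from typing import Iterator

def fake_llm(prompt: str) -> Iterator[str]:
    """假设的LLM: 逐token产出回复(这里用逐字+sleep模拟推理耗时)"""
    for ch in "好的,关于" + prompt[:8] + "的回答是...":
        time.sleep(0.05)  # 模拟每个token约50ms的推理耗时
        yield ch

def batch_pipeline(prompt: str) -> str:
    """批处理: 等全部token生成完毕才返回,用户感知延迟 = 总耗时"""
    return "".join(fake_llm(prompt))

def streaming_pipeline(prompt: str):
    """流式: 每个token一产出就交给下游(如TTS),首字延迟≈单token耗时"""
    for token in fake_llm(prompt):
        print(token, end="", flush=True)  # 下游模块可以在这里同步开始工作

if __name__ == "__main__":
    start = time.time()
    streaming_pipeline("流式推理")
    print(f"\n总耗时: {time.time() - start:.2f}s,但首字在约0.05s时就已送出")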

2. 流式推理技术选型与优化策略

在构建高性能的流式系统之前,我们需要深入了解各种技术方案的优劣势。这就像选择合适的工具一样,不同的场景需要不同的解决方案。

流式推理优化技术对比分析

流式ASR技术对比

| 技术方案 | 实时性 | 准确率 | 资源占用 | 实现复杂度 | 适用场景 |
| --- | --- | --- | --- | --- | --- |
| Whisper实时流式 | 中等 | 95%+ | — | 中等 | 高质量要求 |
| Wav2Vec2流式 | — | 90%+ | — | 中等 | 实时交互 |
| 在线ASR API | — | 98%+ | — | — | 网络稳定环境 |
| 自定义轻量模型 | 极高 | 85%+ | — | 极高 | 边缘设备 |

流式LLM推理优化

| 优化策略 | 延迟改善 | 实现难度 | 资源需求 | 效果稳定性 |
| --- | --- | --- | --- | --- |
| KV-Cache优化 | 30-50% | 中等 | 内存+20% | — |
| 动态批处理 | 40-60% | — | GPU利用率+30% | — |
| 模型量化 | 20-40% | — | 内存-50% | 中等 |
| 流水线并行 | 50-70% | 极高 | 多GPU | — |
| 投机解码 | 60-80% | — | 内存+40% | 中等 |
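
表中的"模型量化"是最容易落地的一项,可以先单独感受一下。下面是用 transformers 的 BitsAndBytesConfig 做4-bit量化加载的示意写法(假设已安装 bitsandbytes 且有可用CUDA GPU,gpt2 仅为示例模型名,实际应换成你的对话模型):

# quantization_demo.py - 4-bit量化加载示意(假设环境: transformers + bitsandbytes + CUDA)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "gpt2"  # 示例模型,按需替换

# NF4量化 + float16计算: 显存占用大幅下降,生成质量损失通常可接受
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # 由accelerate自动放置到可用设备
)

inputs = tokenizer("流式推理的优势是", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=30)[0]))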

流式TTS合成策略

| 技术方案 | 首字节延迟 | 音质 | 资源占用 | 适用场景 |
| --- | --- | --- | --- | --- |
| FastSpeech2流式 | <200ms | 优秀 | 中等 | 生产环境 |
| VITS流式 | <500ms | 极优 | — | 高质量场景 |
| 神经声码器 | <100ms | 良好 | — | 实时应用 |
| 混合合成 | <300ms | 优秀 | 中等 | 平衡方案 |

并发处理架构

| 并发模式 | 吞吐量提升 | 延迟影响 | 复杂度 | 错误处理 |
| --- | --- | --- | --- | --- |
| 多线程 | 2-4x | 略增加 | 中等 | 复杂 |
| 异步IO | 5-10x | 显著降低 | — | 复杂 |
| 进程池 | 线性扩展 | 显著增加 | 简单 | — |
| 协程 | 10-50x | 最低 | 中等 | 中等 |
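
表中协程的吞吐优势可以用下面这个极简的 asyncio 示例验证,其中 fake_asr_request 是假设的占位协程,用 sleep 模拟一次100ms的网络/推理IO等待:

# concurrency_demo.py - 协程并发处理示意
import asyncio
import time

async def fake_asr_request(i: int) -> str:
    """假设的ASR请求: 用sleep模拟100ms的IO等待"""
    await asyncio.sleep(0.1)
    return f"result-{i}"

async def main():
    start = time.time()
    # 100个请求并发执行: 总耗时约等于单个请求的耗时,而非100倍
    results = await asyncio.gather(*(fake_asr_request(i) for i in range(100)))
    print(f"{len(results)} 个请求完成,耗时 {time.time() - start:.2f}s")

if __name__ == "__main__":
    asyncio.run(main())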

性能优化指标

| 优化目标 | 目标值 | 当前值 | 改善空间 | 关键技术 |
| --- | --- | --- | --- | --- |
| 首字节延迟 | <500ms | 1-3s | 60-80% | VAD+流式ASR |
| 端到端延迟 | <2s | 3-8s | 50-70% | 全链路流式 |
| 并发处理 | 100+ | 10-20 | 5-10x | 异步架构 |
| 内存效率 | <4GB | 8-16GB | 50-75% | 模型量化 |
| GPU利用率 | >80% | 30-50% | 60-160% | 批处理优化 |
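
其中最值得盯紧的是首字节延迟。下面是一个通用的测量草图,measure_latency 是假设的辅助函数,可以套在任何"逐块产出"的流式接口外面:

# latency_probe.py - 首字节/端到端延迟测量示意
import time
from typing import Iterable, Dict

def measure_latency(stream: Iterable) -> Dict[str, float]:
    """消费一个流式输出,统计首字节延迟与端到端延迟"""
    start = time.time()
    first_byte = None
    for _chunk in stream:
        if first_byte is None:
            first_byte = time.time() - start  # 第一块产出的时刻
    return {
        "first_byte_latency": first_byte or 0.0,
        "end_to_end_latency": time.time() - start,
    }

# 用法示意: metrics = measure_latency(some_streaming_generator)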

生产环境部署考量

| 部署方式 | 扩展性 | 稳定性 | 运维复杂度 | 成本 |
| --- | --- | --- | --- | --- |
| 单机部署 | — | 中等 | — | — |
| 容器化部署 | 中等 | 中等 | 中等 | — |
| 微服务架构 | — | — | — | — |
| Serverless | 极高 | 中等 | — | 弹性 |
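
无论选择哪种部署方式,容器编排平台都要求服务暴露健康检查探针。下面是一个最小的示意端点,这里假设引入 FastAPI(并非本文既有依赖,仅作草图):

# health_endpoint.py - 健康检查端点示意(假设使用FastAPI + uvicorn)
from fastapi import FastAPI

app = FastAPI()

@app.get("/health")
def health() -> dict:
    """存活探针: k8s/Docker定期调用,失败则重启容器"""
    return {"status": "ok"}

@app.get("/ready")
def ready() -> dict:
    """就绪探针: 模型全部加载完成后才放流量,避免打到冷启动实例"""
    model_loaded = True  # 实际应检查ASR/LLM/TTS引擎是否已初始化完成
    return {"status": "ready" if model_loaded else "loading"}

# 启动: uvicorn health_endpoint:app --host 0.0.0.0 --port 8000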

实时性能要求分级

| 应用场景 | 延迟要求 | 并发要求 | 可靠性要求 | 推荐方案 |
| --- | --- | --- | --- | --- |
| 演示原型 | <5s | 1-10人 | 一般 | 单机+基础优化 |
| 企业内部 | <3s | 50-100人 | — | 容器+流式优化 |
| 商业产品 | <2s | 1000+人 | 极高 | 微服务+全优化 |
| 实时客服 | <1s | 10000+人 | 极高 | 专用硬件+定制优化 |

3. 流式ASR处理器实现

让我们开始实现真正的流式ASR处理器!这个模块就像一个敏感的"耳朵",能够实时捕捉和处理语音信号,并且支持VAD(语音活动检测)来智能识别何时开始和结束语音。

# streaming_asr.py - 流式ASR处理器
import torch
import numpy as np
import threading
import queue
import time
import logging
from typing import Optional, Callable, Dict, Any, List
from dataclasses import dataclass
import webrtcvad
from collections import deque
import asyncio

# 导入基础ASR模块
from asr_module import ASREngine

logger = logging.getLogger(__name__)

@dataclass
class VADConfig:
    """VAD配置参数"""
    aggressiveness: int = 3  # VAD敏感度 (0-3)
    frame_duration: int = 30  # 帧长度 (10, 20, 30 ms)
    sample_rate: int = 16000
    silence_threshold: int = 20  # 连续静音帧数阈值
    speech_threshold: int = 5   # 连续语音帧数阈值

@dataclass
class StreamingChunk:
    """流式音频块"""
    audio_data: np.ndarray
    timestamp: float
    is_speech: bool
    chunk_id: int
    metadata: Optional[Dict[str, Any]] = None

class StreamingVAD:
    """流式语音活动检测"""
    
    def __init__(self, config: VADConfig):
        self.config = config
        self.vad = webrtcvad.Vad(config.aggressiveness)
        
        # 状态管理
        self.is_speaking = False
        self.silence_frames = 0
        self.speech_frames = 0
        
        # 音频缓冲
        self.audio_buffer = deque(maxlen=100)
        self.speech_buffer = []
        
        logger.info(f" VAD初始化完成,敏感度: {config.aggressiveness}")
    
    def process_frame(self, audio_frame: bytes) -> Dict[str, Any]:
        """
        处理音频帧
        
        Args:
            audio_frame: 原始音频帧数据
            
        Returns:
            处理结果字典
        """
        # VAD检测
        is_speech = self.vad.is_speech(audio_frame, self.config.sample_rate)
        
        result = {
            "is_speech": is_speech,
            "is_speaking": self.is_speaking,
            "state_change": False,
            "audio_data": None
        }
        
        if is_speech:
            self.speech_frames += 1
            self.silence_frames = 0
            
            # 添加到语音缓冲区
            self.speech_buffer.append(audio_frame)
            
            # 检查是否开始说话
            if not self.is_speaking and self.speech_frames >= self.config.speech_threshold:
                self.is_speaking = True
                result["state_change"] = True
                result["event"] = "speech_start"
                logger.debug(" 检测到语音开始")
        else:
            self.silence_frames += 1
            self.speech_frames = 0
            
            # 如果正在说话,继续添加到缓冲区(可能是短暂停顿)
            if self.is_speaking:
                self.speech_buffer.append(audio_frame)
            
            # 检查是否停止说话
            if self.is_speaking and self.silence_frames >= self.config.silence_threshold:
                self.is_speaking = False
                result["state_change"] = True
                result["event"] = "speech_end"
                
                # 返回完整的语音数据
                if self.speech_buffer:
                    result["audio_data"] = b"".join(self.speech_buffer)
                    self.speech_buffer.clear()
                
                logger.debug(" 检测到语音结束")
        
        return result

class StreamingASRProcessor:
    """流式ASR处理器"""
    
    def __init__(
        self,
        asr_engine: ASREngine,
        vad_config: Optional[VADConfig] = None,
        chunk_duration: float = 0.5,
        overlap_duration: float = 0.1,
        max_chunk_size: int = 16000 * 10  # 10秒最大块大小
    ):
        """
        初始化流式ASR处理器
        
        Args:
            asr_engine: ASR引擎实例
            vad_config: VAD配置
            chunk_duration: 音频块持续时间(秒)
            overlap_duration: 重叠时间(秒)
            max_chunk_size: 最大块大小
        """
        self.asr_engine = asr_engine
        self.vad_config = vad_config or VADConfig()
        self.chunk_duration = chunk_duration
        self.overlap_duration = overlap_duration
        self.max_chunk_size = max_chunk_size
        
        # 初始化VAD
        self.vad = StreamingVAD(self.vad_config)
        
        # 流式处理状态
        self.is_processing = False
        self.audio_queue = queue.Queue()
        self.result_queue = queue.Queue()
        
        # 音频缓冲和上下文
        self.audio_buffer = deque()
        self.context_buffer = deque(maxlen=5)  # 保留上下文
        
        # 处理线程
        self.processing_thread = None
        
        # 回调函数
        self.callbacks = {
            "on_partial_result": [],
            "on_final_result": [],
            "on_speech_start": [],
            "on_speech_end": [],
            "on_error": []
        }
        
        # 性能统计
        self.stats = {
            "total_chunks": 0,
            "processing_times": [],
            "recognition_times": [],
            "queue_sizes": []
        }
        
        logger.info(" 流式ASR处理器初始化完成")
    
    def add_callback(self, event: str, callback: Callable):
        """添加回调函数"""
        if event in self.callbacks:
            self.callbacks[event].append(callback)
    
    def _trigger_callbacks(self, event: str, data: Any = None):
        """触发回调函数"""
        for callback in self.callbacks.get(event, []):
            try:
                callback(data)
            except Exception as e:
                logger.error(f"❌ 回调函数执行失败 ({event}): {e}")
    
    def start_streaming(self):
        """开始流式处理"""
        if self.is_processing:
            logger.warning("⚠️  流式处理已在运行")
            return
        
        self.is_processing = True
        self.processing_thread = threading.Thread(
            target=self._processing_loop,
            daemon=True
        )
        self.processing_thread.start()
        
        logger.info(" 流式ASR处理已启动")
    
    def stop_streaming(self):
        """停止流式处理"""
        if not self.is_processing:
            return
        
        self.is_processing = False
        if self.processing_thread:
            self.processing_thread.join(timeout=2)
        
        logger.info(" 流式ASR处理已停止")
    
    def add_audio_chunk(self, audio_data: np.ndarray, timestamp: Optional[float] = None):
        """
        添加音频块到处理队列
        
        Args:
            audio_data: 音频数据
            timestamp: 时间戳
        """
        if not self.is_processing:
            logger.warning("⚠️  流式处理未启动")
            return
        
        chunk = StreamingChunk(
            audio_data=audio_data,
            timestamp=timestamp or time.time(),
            is_speech=False,
            chunk_id=self.stats["total_chunks"]
        )
        
        self.audio_queue.put(chunk)
        self.stats["total_chunks"] += 1
        self.stats["queue_sizes"].append(self.audio_queue.qsize())
    
    def _processing_loop(self):
        """主处理循环"""
        logger.info(" 开始流式处理循环")
        
        while self.is_processing:
            try:
                # 获取音频块
                chunk = self.audio_queue.get(timeout=0.1)
                
                # 处理音频块
                self._process_audio_chunk(chunk)
                
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"❌ 处理循环错误: {e}")
                self._trigger_callbacks("on_error", {"error": str(e)})
    
    def _process_audio_chunk(self, chunk: StreamingChunk):
        """处理单个音频块"""
        start_time = time.time()
        
        try:
            # 转换音频格式用于VAD
            # 注意: webrtcvad只接受10/20/30ms的16-bit单声道PCM帧,
            # 因此先把任意长度的音频块切成固定长度的VAD帧,再逐帧检测
            audio_int16 = (chunk.audio_data * 32767).astype(np.int16)
            frame_samples = int(
                self.vad_config.sample_rate * self.vad_config.frame_duration / 1000
            )
            
            for start in range(0, len(audio_int16) - frame_samples + 1, frame_samples):
                frame_bytes = audio_int16[start:start + frame_samples].tobytes()
                vad_result = self.vad.process_frame(frame_bytes)
                chunk.is_speech = chunk.is_speech or vad_result["is_speech"]
                
                # 处理状态变化
                if vad_result["state_change"]:
                    if vad_result["event"] == "speech_start":
                        self._trigger_callbacks("on_speech_start", chunk)
                    elif vad_result["event"] == "speech_end":
                        # 处理完整语音段
                        if vad_result["audio_data"]:
                            self._process_speech_segment(vad_result["audio_data"], chunk.timestamp)
                        self._trigger_callbacks("on_speech_end", chunk)
            
            # 如果正在说话,累积音频数据
            if self.vad.is_speaking:
                self.audio_buffer.append(chunk.audio_data)
                
                # 检查是否需要进行增量识别(按累计采样点数而非块数判断)
                buffered_samples = sum(len(a) for a in self.audio_buffer)
                if buffered_samples >= int(self.chunk_duration * self.vad_config.sample_rate):
                    self._perform_incremental_recognition()
            
            # 记录处理时间
            processing_time = time.time() - start_time
            self.stats["processing_times"].append(processing_time)
            
        except Exception as e:
            logger.error(f"❌ 音频块处理失败: {e}")
            self._trigger_callbacks("on_error", {"error": str(e), "chunk_id": chunk.chunk_id})
    
    def _process_speech_segment(self, audio_bytes: bytes, timestamp: float):
        """处理完整的语音段"""
        try:
            # 转换音频数据
            audio_array = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32767.0
            
            # 进行ASR识别
            recognition_start = time.time()
            result = self.asr_engine.transcribe_audio_data(
                audio_array,
                sample_rate=self.vad_config.sample_rate
            )
            recognition_time = time.time() - recognition_start
            
            self.stats["recognition_times"].append(recognition_time)
            
            # 触发最终结果回调
            final_result = {
                "text": result.get("text", ""),
                "confidence": result.get("confidence", 0.0),
                "timestamp": timestamp,
                "recognition_time": recognition_time,
                "type": "final"
            }
            
            self._trigger_callbacks("on_final_result", final_result)
            
            logger.info(f"✅ 语音段识别完成: {final_result['text'][:50]}...")
            
        except Exception as e:
            logger.error(f"❌ 语音段处理失败: {e}")
            self._trigger_callbacks("on_error", {"error": str(e)})
    
    def _perform_incremental_recognition(self):
        """执行增量识别"""
        if not self.audio_buffer:
            return
        
        try:
            # 合并音频数据
            combined_audio = np.concatenate(list(self.audio_buffer))
            
            # 限制最大长度
            if len(combined_audio) > self.max_chunk_size:
                combined_audio = combined_audio[-self.max_chunk_size:]
            
            # 进行增量识别
            recognition_start = time.time()
            result = self.asr_engine.transcribe_audio_data(
                combined_audio,
                sample_rate=self.vad_config.sample_rate
            )
            recognition_time = time.time() - recognition_start
            
            # 触发部分结果回调
            partial_result = {
                "text": result.get("text", ""),
                "confidence": result.get("confidence", 0.0),
                "timestamp": time.time(),
                "recognition_time": recognition_time,
                "type": "partial"
            }
            
            self._trigger_callbacks("on_partial_result", partial_result)
            
            # 清理旧的音频数据,保留一些重叠
            overlap_samples = int(self.overlap_duration * self.vad_config.sample_rate)
            if len(combined_audio) > overlap_samples:
                self.audio_buffer.clear()
                self.audio_buffer.append(combined_audio[-overlap_samples:])
            
        except Exception as e:
            logger.error(f"❌ 增量识别失败: {e}")
    
    def get_statistics(self) -> Dict[str, Any]:
        """获取处理统计信息"""
        stats = self.stats.copy()
        
        if stats["processing_times"]:
            stats["avg_processing_time"] = np.mean(stats["processing_times"])
            stats["max_processing_time"] = np.max(stats["processing_times"])
        
        if stats["recognition_times"]:
            stats["avg_recognition_time"] = np.mean(stats["recognition_times"])
            stats["max_recognition_time"] = np.max(stats["recognition_times"])
        
        if stats["queue_sizes"]:
            stats["avg_queue_size"] = np.mean(stats["queue_sizes"])
            stats["max_queue_size"] = np.max(stats["queue_sizes"])
        
        stats["is_processing"] = self.is_processing
        stats["vad_is_speaking"] = self.vad.is_speaking
        stats["buffer_size"] = len(self.audio_buffer)
        
        return stats

class AsyncStreamingASR:
    """异步流式ASR处理器"""
    
    def __init__(self, asr_processor: StreamingASRProcessor):
        self.asr_processor = asr_processor
        self.loop = None
        self.tasks = set()
    
    async def start_async_processing(self):
        """启动异步处理"""
        self.loop = asyncio.get_running_loop()
        self.asr_processor.start_streaming()
        
        # 启动监控任务
        monitor_task = self.loop.create_task(self._monitor_performance())
        self.tasks.add(monitor_task)
        
        logger.info(" 异步流式ASR处理已启动")
    
    async def _monitor_performance(self):
        """性能监控任务"""
        while self.asr_processor.is_processing:
            try:
                stats = self.asr_processor.get_statistics()
                
                # 检查性能指标
                if stats.get("avg_queue_size", 0) > 10:
                    logger.warning(f"⚠️  音频队列积压: {stats['avg_queue_size']:.1f}")
                
                if stats.get("avg_processing_time", 0) > 0.1:
                    logger.warning(f"⚠️  处理延迟过高: {stats['avg_processing_time']:.3f}s")
                
                await asyncio.sleep(5)  # 每5秒检查一次
                
            except Exception as e:
                logger.error(f"❌ 性能监控错误: {e}")
                await asyncio.sleep(1)
    
    async def stop_async_processing(self):
        """停止异步处理"""
        self.asr_processor.stop_streaming()
        
        # 取消所有任务
        for task in self.tasks:
            task.cancel()
        
        # 等待任务完成
        if self.tasks:
            await asyncio.gather(*self.tasks, return_exceptions=True)
        
        logger.info(" 异步流式ASR处理已停止")

# 使用示例和测试代码
if __name__ == "__main__":
    print(" 流式ASR处理器测试")
    print("=" * 50)
    
    # 创建ASR引擎
    from asr_module import ASREngine
    asr_engine = ASREngine(model_name="base")
    
    # 创建流式处理器
    streaming_asr = StreamingASRProcessor(asr_engine)
    
    # 添加回调函数
    def on_partial_result(result):
        print(f" 部分结果: {result['text']}")
    
    def on_final_result(result):
        print(f"✅ 最终结果: {result['text']}")
        print(f"   置信度: {result['confidence']:.2f}")
        print(f"   识别耗时: {result['recognition_time']:.3f}s")
    
    def on_speech_start(chunk):
        print(" 开始说话")
    
    def on_speech_end(chunk):
        print(" 说话结束")
    
    streaming_asr.add_callback("on_partial_result", on_partial_result)
    streaming_asr.add_callback("on_final_result", on_final_result)
    streaming_asr.add_callback("on_speech_start", on_speech_start)
    streaming_asr.add_callback("on_speech_end", on_speech_end)
    
    # 启动流式处理
    streaming_asr.start_streaming()
    
    # 模拟音频输入(实际使用中会从麦克风获取)
    print(" 开始模拟音频输入...")
    try:
        # 生成一些测试音频数据
        duration = 5.0  # 5秒测试音频
        sample_rate = 16000
        
        # 生成带有语音模式的测试信号
        t = np.linspace(0, duration, int(sample_rate * duration))
        
        # 模拟语音: 前1秒静音,中间3秒有声音,后1秒静音
        audio_signal = np.zeros_like(t)
        speech_start = int(sample_rate * 1.0)
        speech_end = int(sample_rate * 4.0)
        
        # 添加语音信号(简单的正弦波组合)
        audio_signal[speech_start:speech_end] = (
            0.3 * np.sin(2 * np.pi * 440 * t[speech_start:speech_end]) +
            0.2 * np.sin(2 * np.pi * 880 * t[speech_start:speech_end])
        )
        
        # 分块发送音频数据
        chunk_size = int(sample_rate * 0.1)  # 100ms块
        for i in range(0, len(audio_signal), chunk_size):
            chunk = audio_signal[i:i+chunk_size]
            streaming_asr.add_audio_chunk(chunk, time.time())
            time.sleep(0.05)  # 模拟实时输入
        
        # 等待处理完成
        time.sleep(2)
        
        # 显示统计信息
        stats = streaming_asr.get_statistics()
        print(f"\n 处理统计:")
        print(f"  总音频块: {stats['total_chunks']}")
        print(f"  平均处理时间: {stats.get('avg_processing_time', 0):.3f}s")
        print(f"  平均识别时间: {stats.get('avg_recognition_time', 0):.3f}s")
        print(f"  平均队列大小: {stats.get('avg_queue_size', 0):.1f}")
        
    except KeyboardInterrupt:
        print("\n 用户中断测试")
    except Exception as e:
        print(f"❌ 测试失败: {e}")
    finally:
        streaming_asr.stop_streaming()
    
    print(" 流式ASR处理器测试完成!")

4. 流式LLM推理引擎

现在让我们实现流式LLM推理引擎!这是整个系统的"大脑中枢",需要能够在接收到部分输入时就开始思考,并且能够实时输出推理结果。就像一个聪明的对话者,能够边听边思考,边思考边回答!

# streaming_llm.py - 流式LLM推理引擎
import torch
import threading
import queue
import time
import logging
from typing import Dict, List, Any, Optional, Callable, Tuple
from dataclasses import dataclass, field
import numpy as np
from collections import deque

# 导入transformers库
try:
    from transformers import (
        AutoTokenizer, AutoModelForCausalLM, GenerationConfig
    )
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False

# 导入基础模块
from llm_module import ChatMessage

logger = logging.getLogger(__name__)

@dataclass
class StreamingConfig:
    """流式推理配置"""
    # 生成参数
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    top_k: int = 50
    do_sample: bool = True
    repetition_penalty: float = 1.1
    
    # 流式参数
    stream_chunk_size: int = 1  # 每次输出的token数
    buffer_size: int = 10  # 输出缓冲区大小
    timeout: float = 30.0  # 生成超时时间
    
    # 优化参数
    use_cache: bool = True
    use_flash_attention: bool = False
    use_kv_cache_optimization: bool = True
    
    # 批处理参数
    max_batch_size: int = 4
    dynamic_batching: bool = True
    batch_timeout: float = 0.1

@dataclass
class StreamingRequest:
    """流式推理请求"""
    request_id: str
    messages: List[ChatMessage]
    config: StreamingConfig
    timestamp: float = field(default_factory=time.time)
    callback: Optional[Callable] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

@dataclass
class StreamingResponse:
    """流式推理响应"""
    request_id: str
    token: str
    is_final: bool
    timestamp: float
    position: int
    total_tokens: Optional[int] = None
    latency: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

class KVCacheManager:
    """KV Cache管理器"""
    
    def __init__(self, max_cache_size: int = 1000):
        self.max_cache_size = max_cache_size
        self.cache = {}
        self.usage_times = {}
        self.cache_hits = 0
        self.cache_misses = 0
    
    def get_cache_key(self, input_ids: torch.Tensor) -> str:
        """生成缓存键"""
        return str(hash(tuple(input_ids.flatten().tolist())))
    
    def get(self, cache_key: str) -> Optional[Tuple]:
        """获取缓存"""
        if cache_key in self.cache:
            self.usage_times[cache_key] = time.time()
            self.cache_hits += 1
            return self.cache[cache_key]
        
        self.cache_misses += 1
        return None
    
    def put(self, cache_key: str, past_key_values: Tuple):
        """存储缓存"""
        # 如果缓存已满,移除最久未使用的
        if len(self.cache) >= self.max_cache_size:
            oldest_key = min(self.usage_times.keys(), 
                           key=lambda k: self.usage_times[k])
            del self.cache[oldest_key]
            del self.usage_times[oldest_key]
        
        self.cache[cache_key] = past_key_values
        self.usage_times[cache_key] = time.time()
    
    def get_stats(self) -> Dict[str, Any]:
        """获取缓存统计"""
        total_requests = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / max(1, total_requests)
        
        return {
            "cache_size": len(self.cache),
            "max_cache_size": self.max_cache_size,
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "hit_rate": hit_rate,
            "memory_usage_mb": self._estimate_memory_usage()
        }
    
    def _estimate_memory_usage(self) -> float:
        """估算内存使用量"""
        total_size = 0
        for cache_value in self.cache.values():
            if isinstance(cache_value, tuple):
                for tensor in cache_value:
                    if isinstance(tensor, torch.Tensor):
                        total_size += tensor.numel() * tensor.element_size()
        return total_size / (1024 * 1024)  # MB

class DynamicBatcher:
    """动态批处理器"""
    
    def __init__(self, max_batch_size: int = 4, timeout: float = 0.1):
        self.max_batch_size = max_batch_size
        self.timeout = timeout
        self.pending_requests = []
        self.batch_queue = queue.Queue()
        self.stats = {
            "total_batches": 0,
            "avg_batch_size": 0,
            "batch_sizes": []
        }
    
    def add_request(self, request: StreamingRequest):
        """添加请求到批处理队列"""
        self.pending_requests.append(request)
        
        # 检查是否可以组成批次
        if len(self.pending_requests) >= self.max_batch_size:
            self._create_batch()
    
    def _create_batch(self):
        """创建批次"""
        if not self.pending_requests:
            return
        
        batch_size = min(len(self.pending_requests), self.max_batch_size)
        batch = self.pending_requests[:batch_size]
        self.pending_requests = self.pending_requests[batch_size:]
        
        self.batch_queue.put(batch)
        
        # 更新统计
        self.stats["total_batches"] += 1
        self.stats["batch_sizes"].append(batch_size)
        if self.stats["batch_sizes"]:
            self.stats["avg_batch_size"] = np.mean(self.stats["batch_sizes"])
    
    def get_batch(self, timeout: Optional[float] = None) -> Optional[List[StreamingRequest]]:
        """获取批次"""
        timeout = timeout or self.timeout
        
        try:
            return self.batch_queue.get(timeout=timeout)
        except queue.Empty:
            # 超时后处理剩余请求
            if self.pending_requests:
                self._create_batch()
                try:
                    return self.batch_queue.get_nowait()
                except queue.Empty:
                    pass
            return None

class StreamingLLMEngine:
    """流式LLM推理引擎"""
    
    def __init__(
        self,
        model_name: str = "gpt2",
        device: str = "auto",
        config: Optional[StreamingConfig] = None,
        cache_dir: str = "./models/llm"
    ):
        """
        初始化流式LLM引擎
        
        Args:
            model_name: 模型名称
            device: 设备类型
            config: 流式配置
            cache_dir: 缓存目录
        """
        if not TRANSFORMERS_AVAILABLE:
            raise ImportError("transformers库不可用")
        
        self.model_name = model_name
        self.config = config or StreamingConfig()
        self.cache_dir = cache_dir
        
        # 设备配置
        if device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = device
        
        # 加载模型
        self._load_model()
        
        # 初始化组件
        self.kv_cache = KVCacheManager()
        self.batcher = DynamicBatcher(
            max_batch_size=self.config.max_batch_size,
            timeout=self.config.batch_timeout
        )
        
        # 处理状态
        self.is_running = False
        self.processing_thread = None
        
        # 请求管理
        self.active_requests = {}
        self.request_counter = 0
        
        # 性能统计
        self.stats = {
            "total_requests": 0,
            "total_tokens": 0,
            "avg_latency": 0,
            "latencies": deque(maxlen=1000),
            "tokens_per_second": 0,
            "token_rates": deque(maxlen=100)
        }
        
        logger.info(f" 流式LLM引擎初始化完成: {model_name} on {self.device}")
    
    def _load_model(self):
        """加载模型和分词器"""
        try:
            logger.info(" 加载分词器...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
                trust_remote_code=True
            )
            
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            
            logger.info(" 加载模型...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
                trust_remote_code=True,
                torch_dtype=torch.float16 if self.device != "cpu" else torch.float32
            )
            
            self.model = self.model.to(self.device)
            self.model.eval()
            
            logger.info("✅ 模型加载成功")
            
        except Exception as e:
            logger.error(f"❌ 模型加载失败: {e}")
            raise
    
    def start_engine(self):
        """启动推理引擎"""
        if self.is_running:
            logger.warning("⚠️  推理引擎已在运行")
            return
        
        self.is_running = True
        self.processing_thread = threading.Thread(
            target=self._processing_loop,
            daemon=True
        )
        self.processing_thread.start()
        
        logger.info(" 流式LLM推理引擎已启动")
    
    def stop_engine(self):
        """停止推理引擎"""
        if not self.is_running:
            return
        
        self.is_running = False
        if self.processing_thread:
            self.processing_thread.join(timeout=5)
        
        logger.info(" 流式LLM推理引擎已停止")
    
    def generate_stream(
        self,
        messages: List[ChatMessage],
        callback: Optional[Callable] = None,
        **kwargs
    ) -> str:
        """
        流式生成文本
        
        Args:
            messages: 消息列表
            callback: 回调函数
            **kwargs: 其他参数
        
        Returns:
            请求ID
        """
        if not self.is_running:
            raise RuntimeError("推理引擎未启动")
        
        # 创建请求
        request_id = f"req_{self.request_counter}_{int(time.time() * 1000)}"
        self.request_counter += 1
        
        # 合并配置
        request_config = StreamingConfig(**{
            **self.config.__dict__,
            **kwargs
        })
        
        request = StreamingRequest(
            request_id=request_id,
            messages=messages,
            config=request_config,
            callback=callback
        )
        
        # 添加到批处理器
        self.batcher.add_request(request)
        self.active_requests[request_id] = request
        
        logger.info(f" 添加生成请求: {request_id}")
        return request_id
    
    def _processing_loop(self):
        """主处理循环"""
        logger.info(" 开始推理处理循环")
        
        while self.is_running:
            try:
                # 获取批次
                batch = self.batcher.get_batch(timeout=0.1)
                if not batch:
                    continue
                
                # 处理批次
                self._process_batch(batch)
                
            except Exception as e:
                logger.error(f"❌ 处理循环错误: {e}")
                time.sleep(0.1)
    
    def _process_batch(self, batch: List[StreamingRequest]):
        """处理批次请求"""
        logger.debug(f" 处理批次,大小: {len(batch)}")
        
        # 对于简化,这里逐个处理(实际可以真正批处理)
        for request in batch:
            try:
                self._process_single_request(request)
            except Exception as e:
                logger.error(f"❌ 请求处理失败 {request.request_id}: {e}")
                self._send_error_response(request, str(e))
    
    def _process_single_request(self, request: StreamingRequest):
        """处理单个请求"""
        start_time = time.time()
        
        try:
            # 构建输入
            input_text = self._format_messages(request.messages)
            
            # 编码输入
            inputs = self.tokenizer.encode(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=4096 - request.config.max_new_tokens
            )
            
            if self.device != "cpu":
                inputs = inputs.to(self.device)
            
            # 检查KV缓存
            cache_key = self.kv_cache.get_cache_key(inputs)
            past_key_values = self.kv_cache.get(cache_key) if request.config.use_kv_cache_optimization else None
            
            # 生成配置
            generation_config = GenerationConfig(
                max_new_tokens=request.config.max_new_tokens,
                temperature=request.config.temperature,
                top_p=request.config.top_p,
                top_k=request.config.top_k,
                do_sample=request.config.do_sample,
                repetition_penalty=request.config.repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                use_cache=request.config.use_cache
            )
            
            # 流式生成
            self._generate_streaming(request, inputs, generation_config, past_key_values)
            
            # 更新统计
            total_time = time.time() - start_time
            self.stats["latencies"].append(total_time)
            self.stats["total_requests"] += 1
            
            if self.stats["latencies"]:
                self.stats["avg_latency"] = np.mean(list(self.stats["latencies"]))
            
        except Exception as e:
            logger.error(f"❌ 请求处理异常: {e}")
            self._send_error_response(request, str(e))
        finally:
            # 清理请求
            if request.request_id in self.active_requests:
                del self.active_requests[request.request_id]
    
    def _generate_streaming(
        self,
        request: StreamingRequest,
        input_ids: torch.Tensor,
        generation_config: GenerationConfig,
        past_key_values: Optional[Tuple] = None
    ):
        """执行流式生成"""
        position = 0
        generated_tokens = 0
        current_input = input_ids
        
        with torch.no_grad():
            while generated_tokens < request.config.max_new_tokens:
                try:
                    # 模型推理
                    outputs = self.model(
                        current_input,
                        past_key_values=past_key_values,
                        use_cache=generation_config.use_cache
                    )
                    
                    # 获取logits和新的past_key_values
                    logits = outputs.logits[:, -1, :]
                    if generation_config.use_cache:
                        past_key_values = outputs.past_key_values
                    
                    # 应用生成参数(注: 为保持示例简洁,GenerationConfig里的
                    # repetition_penalty未在这个手写采样循环中应用)
                    if generation_config.temperature != 1.0:
                        logits = logits / generation_config.temperature
                    
                    if generation_config.top_k > 0:
                        top_k_logits, top_k_indices = torch.topk(logits, generation_config.top_k)
                        logits = torch.full_like(logits, float('-inf'))
                        logits.scatter_(1, top_k_indices, top_k_logits)
                    
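                    # Top-p(核采样)过滤: 将token按概率降序排列并累加,
                    # 累计概率超过top_p之后的低概率token全部置为-inf,只在"核"内采样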
                    if generation_config.top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > generation_config.top_p
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        logits[indices_to_remove] = float('-inf')
                    
                    # 采样下一个token
                    if generation_config.do_sample:
                        probs = torch.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        next_token = torch.argmax(logits, dim=-1, keepdim=True)
                    
                    # 检查结束条件
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break
                    
                    # 解码token
                    token_text = self.tokenizer.decode(next_token[0], skip_special_tokens=True)
                    
                    # 发送响应
                    response = StreamingResponse(
                        request_id=request.request_id,
                        token=token_text,
                        is_final=False,
                        timestamp=time.time(),
                        position=position,
                        latency=time.time() - request.timestamp
                    )
                    
                    self._send_response(request, response)
                    
                    # 更新状态
                    current_input = next_token
                    position += 1
                    generated_tokens += 1
                    
                    # 更新统计
                    self.stats["total_tokens"] += 1
                    
                    # 计算token生成速度
                    elapsed_time = time.time() - request.timestamp
                    if elapsed_time > 0:
                        token_rate = generated_tokens / elapsed_time
                        self.stats["token_rates"].append(token_rate)
                        if self.stats["token_rates"]:
                            self.stats["tokens_per_second"] = np.mean(list(self.stats["token_rates"]))
                    
                except Exception as e:
                    logger.error(f"❌ 生成过程异常: {e}")
                    break
        
        # 发送最终响应
        final_response = StreamingResponse(
            request_id=request.request_id,
            token="",
            is_final=True,
            timestamp=time.time(),
            position=position,
            total_tokens=generated_tokens,
            latency=time.time() - request.timestamp
        )
        
        self._send_response(request, final_response)
        
    
    def _format_messages(self, messages: List[ChatMessage]) -> str:
        """格式化消息"""
        formatted_parts = []
        
        for message in messages:
            if message.role == "system":
                formatted_parts.append(f"System: {message.content}")
            elif message.role == "user":
                formatted_parts.append(f"User: {message.content}")
            elif message.role == "assistant":
                formatted_parts.append(f"Assistant: {message.content}")
        
        formatted_parts.append("Assistant:")
        return "\n".join(formatted_parts)
    
    def _send_response(self, request: StreamingRequest, response: StreamingResponse):
        """发送响应"""
        if request.callback:
            try:
                request.callback(response)
            except Exception as e:
                logger.error(f"❌ 回调函数执行失败: {e}")
    
    def _send_error_response(self, request: StreamingRequest, error: str):
        """发送错误响应"""
        error_response = StreamingResponse(
            request_id=request.request_id,
            token="",
            is_final=True,
            timestamp=time.time(),
            position=0,
            metadata={"error": error}
        )
        self._send_response(request, error_response)
    
    def get_statistics(self) -> Dict[str, Any]:
        """获取统计信息"""
        stats = self.stats.copy()
        stats["active_requests"] = len(self.active_requests)
        stats["cache_stats"] = self.kv_cache.get_stats()
        stats["batch_stats"] = self.batcher.stats
        stats["is_running"] = self.is_running
        
        return stats
    
    def cancel_request(self, request_id: str) -> bool:
        """取消请求"""
        if request_id in self.active_requests:
            del self.active_requests[request_id]
            logger.info(f" 取消请求: {request_id}")
            return True
        return False

# 使用示例和测试代码
if __name__ == "__main__":
    print(" 流式LLM推理引擎测试")
    print("=" * 50)
    
    # 创建流式引擎
    streaming_llm = StreamingLLMEngine(
        model_name="gpt2",
        device="auto"
    )
    
    # 启动引擎
    streaming_llm.start_engine()
    
    # 准备测试消息
    test_messages = [
        ChatMessage(role="system", content="你是一个有帮助的AI助手。"),
        ChatMessage(role="user", content="请介绍一下人工智能的发展历程。")
    ]
    
    # 回调函数
    generated_text = []
    def on_token_generated(response: StreamingResponse):
        if not response.is_final:
            generated_text.append(response.token)
            print(f"Token {response.position}: '{response.token}'", end="", flush=True)
        else:
            print(f"\n✅ 生成完成,总计 {response.total_tokens} tokens")
            print(f"延迟: {response.latency:.2f}s")
    
    try:
        # 开始流式生成
        print(" 开始流式生成...")
        request_id = streaming_llm.generate_stream(
            messages=test_messages,
            callback=on_token_generated,
            max_new_tokens=100,
            temperature=0.7
        )
        
        print(f"请求ID: {request_id}")
        print("生成结果: ", end="")
        
        # 等待生成完成
        time.sleep(10)
        
        # 显示统计信息
        stats = streaming_llm.get_statistics()
        print(f"\n 引擎统计:")
        print(f"  总请求数: {stats['total_requests']}")
        print(f"  总Token数: {stats['total_tokens']}")
        print(f"  平均延迟: {stats['avg_latency']:.2f}s")
        print(f"  Token生成速度: {stats['tokens_per_second']:.1f} tokens/s")
        print(f"  活跃请求: {stats['active_requests']}")
        
        # 缓存统计
        cache_stats = stats['cache_stats']
        print(f"  缓存命中率: {cache_stats['hit_rate']:.2%}")
        print(f"  缓存大小: {cache_stats['cache_size']}")
        print(f"  缓存内存: {cache_stats['memory_usage_mb']:.1f}MB")
        
        # 批处理统计
        batch_stats = stats['batch_stats']
        print(f"  总批次数: {batch_stats['total_batches']}")
        print(f"  平均批次大小: {batch_stats['avg_batch_size']:.1f}")
        
        print(f"\n完整生成文本:")
        print("".join(generated_text))
        
    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        streaming_llm.stop_engine()
    
    print(" 流式LLM推理引擎测试完成!")

5. 流式TTS合成器与音频流处理

接下来我们来实现流式TTS合成器!这是系统的"声带",需要能够在接收到文本流时立即开始合成音频,并且支持流式播放。就像一个专业的播音员,能够边读边播,让用户感受到零延迟的语音体验!

# streaming_tts.py - 流式TTS合成器
import threading
import queue
import time
import logging
import tempfile
from typing import Dict, List, Any, Optional, Callable
from dataclasses import dataclass, field
import numpy as np
from collections import deque
import pygame
import re
import os

# 导入基础TTS模块
from tts_module import TTSManager

logger = logging.getLogger(__name__)

@dataclass
class StreamingTTSConfig:
    """流式TTS配置"""
    # 文本处理
    sentence_splitters: List[str] = field(default_factory=lambda: ['.', '!', '?', '。', '!', '?'])
    min_chunk_length: int = 10  # 最小合成块长度
    max_chunk_length: int = 200  # 最大合成块长度
    
    # 音频处理
    sample_rate: int = 22050
    buffer_duration: float = 2.0  # 音频缓冲时长(秒)
    crossfade_duration: float = 0.1  # 交叉淡化时长(秒)
    
    # 性能参数
    max_concurrent_synthesis: int = 3  # 最大并发合成数
    synthesis_timeout: float = 10.0  # 合成超时时间
    
    # 质量参数
    enable_ssml: bool = False  # 是否启用SSML
    voice_speed: float = 1.0  # 语速
    voice_pitch: float = 1.0  # 音调

@dataclass
class TextChunk:
    """文本块"""
    chunk_id: int
    text: str
    timestamp: float
    is_final: bool = False
    priority: int = 0  # 优先级,数值越小优先级越高
    metadata: Dict[str, Any] = field(default_factory=dict)

@dataclass
class AudioChunk:
    """音频块"""
    chunk_id: int
    audio_data: np.ndarray
    sample_rate: int
    timestamp: float
    duration: float
    is_final: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)

class TextProcessor:
    """文本预处理器"""
    
    def __init__(self, config: StreamingTTSConfig):
        self.config = config
        self.chunk_buffer = ""
        self.chunk_counter = 0
        self.pending_short = ""  # 暂存过短的句子,与后续文本合并后再合成
        
        # 编译正则表达式
        self.sentence_pattern = self._compile_sentence_pattern()
        self.punctuation_pattern = re.compile(r'[,.;:,、;:]')
        
    def _compile_sentence_pattern(self) -> re.Pattern:
        """编译句子分割正则表达式"""
        splitters = ''.join(re.escape(s) for s in self.config.sentence_splitters)
        return re.compile(f'[{splitters}]')
    
    def add_text(self, text: str, is_final: bool = False) -> List[TextChunk]:
        """
        添加文本并返回可处理的文本块
        
        Args:
            text: 输入文本
            is_final: 是否为最终文本
            
        Returns:
            可处理的文本块列表
        """
        chunks = []
        self.chunk_buffer += text
        
        # 查找完整的句子
        while True:
            # 查找句子边界
            match = self.sentence_pattern.search(self.chunk_buffer)
            
            if match:
                # 找到句子结束
                sentence_end = match.end()
                sentence = (self.pending_short + self.chunk_buffer[:sentence_end]).strip()
                
                # 先推进缓冲区,避免对同一位置重复匹配
                self.chunk_buffer = self.chunk_buffer[sentence_end:].lstrip()
                
                if len(sentence) >= self.config.min_chunk_length:
                    chunk = TextChunk(
                        chunk_id=self.chunk_counter,
                        text=sentence,
                        timestamp=time.time(),
                        is_final=False
                    )
                    chunks.append(chunk)
                    self.chunk_counter += 1
                    self.pending_short = ""
                else:
                    # 句子太短: 不直接丢弃,暂存后与下一句合并再合成
                    self.pending_short = sentence + " "
            else:
                # 没有找到完整句子
                if is_final and (self.pending_short + self.chunk_buffer).strip():
                    # 最终文本,处理剩余内容(包含暂存的短句)
                    chunk = TextChunk(
                        chunk_id=self.chunk_counter,
                        text=(self.pending_short + self.chunk_buffer).strip(),
                        timestamp=time.time(),
                        is_final=True
                    )
                    chunks.append(chunk)
                    self.chunk_counter += 1
                    self.chunk_buffer = ""
                    self.pending_short = ""
                elif len(self.chunk_buffer) > self.config.max_chunk_length:
                    # 缓冲区太长,强制分割
                    # 尝试在标点符号处分割
                    punct_match = None
                    for match in self.punctuation_pattern.finditer(self.chunk_buffer):
                        if match.end() >= self.config.min_chunk_length:
                            punct_match = match
                            break
                    
                    if punct_match:
                        split_pos = punct_match.end()
                    else:
                        split_pos = self.config.max_chunk_length
                    
                    chunk_text = self.chunk_buffer[:split_pos].strip()
                    if chunk_text:
                        chunk = TextChunk(
                            chunk_id=self.chunk_counter,
                            text=chunk_text,
                            timestamp=time.time(),
                            is_final=False
                        )
                        chunks.append(chunk)
                        self.chunk_counter += 1
                    
                    self.chunk_buffer = self.chunk_buffer[split_pos:].lstrip()
                else:
                    break
        
        return chunks
    
    def get_buffered_text(self) -> str:
        """获取缓冲区中的文本"""
        return self.chunk_buffer
    
    def clear_buffer(self):
        """清空缓冲区"""
        self.chunk_buffer = ""
        self.pending_short = ""

class AudioBuffer:
    """音频缓冲区"""
    
    def __init__(self, config: StreamingTTSConfig):
        self.config = config
        self.buffer = deque()
        self.buffer_duration = 0.0
        self.max_buffer_duration = config.buffer_duration
        self.is_playing = False
        self.lock = threading.Lock()
        
        # 初始化pygame mixer
        pygame.mixer.pre_init(
            frequency=config.sample_rate,
            size=-16,
            channels=1,
            buffer=1024
        )
        pygame.mixer.init()
        
    def add_audio(self, audio_chunk: AudioChunk):
        """添加音频块到缓冲区"""
        with self.lock:
            self.buffer.append(audio_chunk)
            self.buffer_duration += audio_chunk.duration
            
            # 限制缓冲区大小
            while self.buffer_duration > self.max_buffer_duration:
                if self.buffer:
                    removed_chunk = self.buffer.popleft()
                    self.buffer_duration -= removed_chunk.duration
                else:
                    break
    
    def get_next_audio(self) -> Optional[AudioChunk]:
        """获取下一个音频块"""
        with self.lock:
            if self.buffer:
                chunk = self.buffer.popleft()
                self.buffer_duration -= chunk.duration
                return chunk
            return None
    
    def get_buffer_status(self) -> Dict[str, Any]:
        """获取缓冲区状态"""
        with self.lock:
            return {
                "buffer_size": len(self.buffer),
                "buffer_duration": self.buffer_duration,
                "max_buffer_duration": self.max_buffer_duration,
                "is_playing": self.is_playing
            }
    
    def clear_buffer(self):
        """清空缓冲区"""
        with self.lock:
            self.buffer.clear()
            self.buffer_duration = 0.0

class StreamingTTSSynthesizer:
    """流式TTS合成器"""
    
    def __init__(
        self,
        tts_manager: TTSManager,
        config: Optional[StreamingTTSConfig] = None
    ):
        """
        初始化流式TTS合成器
        
        Args:
            tts_manager: TTS管理器实例
            config: 流式TTS配置
        """
        self.tts_manager = tts_manager
        self.config = config or StreamingTTSConfig()
        
        # 初始化组件
        self.text_processor = TextProcessor(self.config)
        self.audio_buffer = AudioBuffer(self.config)
        
        # 处理状态
        self.is_running = False
        self.synthesis_threads = []
        self.playback_thread = None
        
        # 队列管理
        self.text_queue = queue.PriorityQueue()
        self.synthesis_queue = queue.Queue()
        
        # 统计信息
        self.stats = {
            "total_text_chunks": 0,
            "total_audio_chunks": 0,
            "avg_synthesis_time": 0,
            "synthesis_times": deque(maxlen=100),
            "first_audio_latency": None,
            "total_synthesis_latency": 0
        }
        
        # 回调函数
        self.callbacks = {
            "on_audio_ready": [],
            "on_synthesis_complete": [],
            "on_playback_start": [],
            "on_playback_complete": [],
            "on_error": []
        }
        
        logger.info(" 流式TTS合成器初始化完成")
    
    def add_callback(self, event: str, callback: Callable):
        """添加回调函数"""
        if event in self.callbacks:
            self.callbacks[event].append(callback)
    
    def _trigger_callbacks(self, event: str, data: Any = None):
        """触发回调函数"""
        for callback in self.callbacks.get(event, []):
            try:
                callback(data)
            except Exception as e:
                logger.error(f"❌ 回调函数执行失败 ({event}): {e}")
    
    def start_synthesis(self):
        """启动合成器"""
        if self.is_running:
            logger.warning("⚠️  合成器已在运行")
            return
        
        self.is_running = True
        
        # 启动合成线程池
        for i in range(self.config.max_concurrent_synthesis):
            thread = threading.Thread(
                target=self._synthesis_worker,
                daemon=True,
                name=f"TTS-Synthesis-{i}"
            )
            thread.start()
            self.synthesis_threads.append(thread)
        
        # 启动播放线程
        self.playback_thread = threading.Thread(
            target=self._playback_worker,
            daemon=True,
            name="TTS-Playback"
        )
        self.playback_thread.start()
        
        logger.info(" 流式TTS合成器已启动")
    
    def stop_synthesis(self):
        """停止合成器"""
        if not self.is_running:
            return
        
        self.is_running = False
        
        # 等待线程结束
        for thread in self.synthesis_threads:
            thread.join(timeout=2)
        
        if self.playback_thread:
            self.playback_thread.join(timeout=2)
        
        # 清理资源
        self.audio_buffer.clear_buffer()
        pygame.mixer.quit()
        
        logger.info(" 流式TTS合成器已停止")
    
    def add_text(self, text: str, is_final: bool = False, priority: int = 0):
        """
        添加文本进行合成
        
        Args:
            text: 输入文本
            is_final: 是否为最终文本
            priority: 优先级
        """
        if not self.is_running:
            logger.warning("⚠️  合成器未启动")
            return
        
        # 处理文本
        text_chunks = self.text_processor.add_text(text, is_final)
        
        # 添加到队列
        for chunk in text_chunks:
            chunk.priority = priority
            self.text_queue.put((priority, chunk.chunk_id, chunk))
            self.stats["total_text_chunks"] += 1
        
        logger.debug(f" 添加文本块: {len(text_chunks)} 个")
    
    def _synthesis_worker(self):
        """合成工作线程"""
        logger.debug(f" 合成工作线程启动: {threading.current_thread().name}")
        
        while self.is_running:
            try:
                # 获取文本块
                priority, chunk_id, text_chunk = self.text_queue.get(timeout=0.1)
                
                # 执行合成
                self._synthesize_chunk(text_chunk)
                
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"❌ 合成工作线程错误: {e}")
                self._trigger_callbacks("on_error", {"error": str(e)})
    
    def _synthesize_chunk(self, text_chunk: TextChunk):
        """合成文本块"""
        start_time = time.time()
        
        try:
            logger.debug(f" 开始合成: {text_chunk.text[:30]}...")
            
            # 进行TTS合成
            audio_data = self.tts_manager.synthesize(text_chunk.text)
            
            if audio_data:
                # 转换音频数据
                audio_array = self._convert_audio_data(audio_data)
                
                # 创建音频块
                audio_chunk = AudioChunk(
                    chunk_id=text_chunk.chunk_id,
                    audio_data=audio_array,
                    sample_rate=self.config.sample_rate,
                    timestamp=time.time(),
                    duration=len(audio_array) / self.config.sample_rate,
                    is_final=text_chunk.is_final
                )
                
                # 添加到音频缓冲区
                self.audio_buffer.add_audio(audio_chunk)
                
                # 更新统计
                synthesis_time = time.time() - start_time
                self.stats["synthesis_times"].append(synthesis_time)
                self.stats["total_audio_chunks"] += 1
                
                if self.stats["synthesis_times"]:
                    self.stats["avg_synthesis_time"] = np.mean(list(self.stats["synthesis_times"]))
                
                # 记录首次音频延迟
                if self.stats["first_audio_latency"] is None:
                    self.stats["first_audio_latency"] = synthesis_time
                
                self._trigger_callbacks("on_audio_ready", audio_chunk)
                
                logger.debug(f"✅ 合成完成: {text_chunk.chunk_id}, 耗时: {synthesis_time:.3f}s")
            
        except Exception as e:
            logger.error(f"❌ 文本块合成失败: {e}")
            self._trigger_callbacks("on_error", {
                "error": str(e),
                "chunk_id": text_chunk.chunk_id,
                "text": text_chunk.text
            })
    
    def _convert_audio_data(self, audio_data: bytes) -> np.ndarray:
        """转换音频数据格式"""
        try:
            # 保存到临时文件
            with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as temp_file:
                temp_file.write(audio_data)
                temp_path = temp_file.name
            
            # 使用pygame加载音频
            sound = pygame.mixer.Sound(temp_path)
            
            # 获取原始音频数据
            raw_data = pygame.sndarray.array(sound)
            
            # 转换为float32格式
            if raw_data.dtype == np.int16:
                audio_array = raw_data.astype(np.float32) / 32767.0
            else:
                audio_array = raw_data.astype(np.float32)
            
            # 确保是单声道
            if len(audio_array.shape) > 1:
                audio_array = np.mean(audio_array, axis=1)
            
            # 清理临时文件
            os.unlink(temp_path)
            
            return audio_array
            
        except Exception as e:
            logger.error(f"❌ 音频数据转换失败: {e}")
            return np.array([])
    
    def _playback_worker(self):
        """播放工作线程"""
        logger.debug(" 播放工作线程启动")
        
        self._trigger_callbacks("on_playback_start", None)
        
        while self.is_running:
            try:
                # 获取音频块
                audio_chunk = self.audio_buffer.get_next_audio()
                
                if audio_chunk:
                    # 播放音频
                    self._play_audio_chunk(audio_chunk)
                    
                    if audio_chunk.is_final:
                        self._trigger_callbacks("on_synthesis_complete", None)
                else:
                    time.sleep(0.01)  # 没有音频时短暂等待
                    
            except Exception as e:
                logger.error(f"❌ 播放工作线程错误: {e}")
                time.sleep(0.1)
        
        self._trigger_callbacks("on_playback_complete", None)
    
    def _play_audio_chunk(self, audio_chunk: AudioChunk):
        """播放音频块"""
        try:
            # 转换为int16格式
            audio_int16 = (audio_chunk.audio_data * 32767).astype(np.int16)
            
            # 创建pygame Sound对象
            sound = pygame.sndarray.make_sound(audio_int16)
            
            # 播放音频
            sound.play()
            
            # 等待播放完成
            while pygame.mixer.get_busy():
                time.sleep(0.01)
            
            logger.debug(f" 播放完成: chunk {audio_chunk.chunk_id}")
            
        except Exception as e:
            logger.error(f"❌ 音频播放失败: {e}")
    
    def get_synthesis_status(self) -> Dict[str, Any]:
        """获取合成状态"""
        status = {
            "is_running": self.is_running,
            "text_queue_size": self.text_queue.qsize(),
            "buffered_text": self.text_processor.get_buffered_text(),
            "audio_buffer": self.audio_buffer.get_buffer_status(),
            "stats": self.stats.copy()
        }
        
        return status
    
    def clear_queues(self):
        """清空所有队列"""
        # 清空文本队列
        while not self.text_queue.empty():
            try:
                self.text_queue.get_nowait()
            except queue.Empty:
                break
        
        # 清空音频缓冲区
        self.audio_buffer.clear_buffer()
        
        # 清空文本处理器缓冲区
        self.text_processor.clear_buffer()
        
        logger.info(" 所有队列已清空")

# 使用示例和测试代码
if __name__ == "__main__":
    print(" 流式TTS合成器测试")
    print("=" * 50)
    
    # 创建TTS管理器
    from tts_module import TTSManager
    tts_manager = TTSManager()
    
    # 创建流式合成器
    streaming_tts = StreamingTTSSynthesizer(tts_manager)
    
    # 添加回调函数
    def on_audio_ready(audio_chunk):
        print(f" 音频就绪: chunk {audio_chunk.chunk_id}, 时长: {audio_chunk.duration:.2f}s")
    
    def on_synthesis_complete(_data):
        print("✅ 合成完成")
    
    def on_playback_start(_data):
        print(" 开始播放")
    
    def on_playback_complete(_data):
        print(" 播放完成")
    
    def on_error(error_data):
        print(f"❌ 错误: {error_data}")
    
    streaming_tts.add_callback("on_audio_ready", on_audio_ready)
    streaming_tts.add_callback("on_synthesis_complete", on_synthesis_complete)
    streaming_tts.add_callback("on_playback_start", on_playback_start)
    streaming_tts.add_callback("on_playback_complete", on_playback_complete)
    streaming_tts.add_callback("on_error", on_error)
    
    try:
        # 启动合成器
        streaming_tts.start_synthesis()
        
        # 测试文本
        test_texts = [
            "你好,我是AI语音助手。",
            "今天天气很好,适合出门散步。",
            "人工智能技术正在快速发展,",
            "为我们的生活带来了很多便利。",
            "感谢您的使用,再见!"
        ]
        
        # 模拟流式文本输入
        print(" 开始流式文本输入...")
        for i, text in enumerate(test_texts):
            print(f"输入文本 {i+1}: {text}")
            is_final = (i == len(test_texts) - 1)
            streaming_tts.add_text(text, is_final=is_final)
            time.sleep(1)  # 模拟流式输入间隔
        
        # 等待处理完成
        print("⏳ 等待处理完成...")
        time.sleep(10)
        
        # 显示状态
        status = streaming_tts.get_synthesis_status()
        print(f"\n 合成状态:")
        print(f"  文本队列大小: {status['text_queue_size']}")
        print(f"  缓冲文本: '{status['buffered_text']}'")
        print(f"  音频缓冲: {status['audio_buffer']}")
        
        stats = status['stats']
        print(f"  总文本块: {stats['total_text_chunks']}")
        print(f"  总音频块: {stats['total_audio_chunks']}")
        print(f"  平均合成时间: {stats['avg_synthesis_time']:.3f}s")
        print(f"  首次音频延迟: {stats['first_audio_latency']:.3f}s")
        
    except Exception as e:
        print(f"❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        streaming_tts.stop_synthesis()
    
    print(" 流式TTS合成器测试完成!")

怎么样,今天的内容还满意吗?再次感谢朋友们的观看,我们下一讲再见!
