# Request Batching and Queueing

GPU utilization during LLM inference is highly sensitive to batch size. Well-designed batching and request queueing can improve throughput by 3-10x.
## Batching Strategies

```mermaid
graph TB
    A[Batching strategies] --> B[Static batching]
    A --> C[Dynamic batching]
    A --> D[Continuous batching]
    B --> B1[Fixed batch size<br/>wait until the batch is full]
    C --> C1[Variable batch size<br/>triggered by timeout or count]
    D --> D1[Continuous Batching<br/>requests join as they arrive]
    style A fill:#e3f2fd,stroke:#1976d2,stroke-width:3px
    style D fill:#e8f5e9,stroke:#388e3c,stroke-width:2px
```
### Strategy Comparison

| Strategy | Throughput | Latency | Implementation complexity | Best suited for |
|---|---|---|---|---|
| Static batching | Medium | High (waiting) | Low | Offline jobs |
| Dynamic batching | High | Medium | Medium | Online inference |
| Continuous batching | Highest | Lowest | High | High-concurrency real-time |
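To see why continuous batching tops the table, a toy simulation helps. This is a simplified model of my own (one token generated per active sequence per step, ignoring prefill cost and memory limits); `static_batching_steps` and `continuous_batching_steps` are illustrative names, not from any framework:

```python
def static_batching_steps(lengths: list[int], max_batch: int) -> int:
    """Static batching: each batch occupies the GPU until its
    longest sequence finishes, so short requests pad out the batch."""
    steps = 0
    for i in range(0, len(lengths), max_batch):
        steps += max(lengths[i:i + max_batch])
    return steps


def continuous_batching_steps(lengths: list[int], max_batch: int) -> int:
    """Continuous batching: a finished sequence frees its slot
    immediately, so a waiting request joins at the next decode step."""
    pending = list(lengths)
    active: list[int] = []
    steps = 0
    while pending or active:
        while pending and len(active) < max_batch:
            active.append(pending.pop(0))          # admit new requests
        active = [n - 1 for n in active if n > 1]  # one decode step each
        steps += 1
    return steps
```

With output lengths `[5, 1, 5, 1]` and `max_batch=2`, static batching needs 10 decode steps while continuous batching needs 6, because the short requests stop blocking the long ones.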
## Dynamic Batcher

```python
"""
Dynamic batching request queue.
"""
import time
import threading
from dataclasses import dataclass, field
from queue import Queue, Empty
from typing import Any


@dataclass
class InferenceRequest:
    """A single inference request."""
    request_id: str
    prompt: str
    max_tokens: int = 512
    priority: int = 0  # higher value = higher priority
    created_at: float = field(default_factory=time.time)
    result: Any = None
    event: threading.Event = field(default_factory=threading.Event)


@dataclass
class BatchConfig:
    """Batching configuration."""
    max_batch_size: int = 32
    max_wait_ms: float = 50.0      # maximum wait before flushing a batch
    max_total_tokens: int = 8192   # token budget per batch
    priority_classes: int = 3      # number of priority levels


class DynamicBatcher:
    """Groups queued requests by batch size, token budget, and timeout."""

    def __init__(self, config: BatchConfig, inference_fn=None):
        self.config = config
        self.inference_fn = inference_fn
        self._queues: list[Queue] = [
            Queue() for _ in range(config.priority_classes)
        ]
        self._running = False
        self._stats = {"batches": 0, "total_requests": 0, "avg_batch_size": 0}

    def submit(self, request: InferenceRequest) -> InferenceRequest:
        """Enqueue a request into its priority class."""
        priority = max(0, min(request.priority, self.config.priority_classes - 1))
        self._queues[priority].put(request)
        return request

    def wait_for_result(self, request: InferenceRequest, timeout: float = 30.0) -> Any:
        """Block until the request's result is ready (or timeout expires)."""
        request.event.wait(timeout=timeout)
        return request.result

    def start(self):
        """Start the batching loop in a background thread."""
        self._running = True
        thread = threading.Thread(target=self._batch_loop, daemon=True)
        thread.start()

    def stop(self):
        self._running = False

    def _batch_loop(self):
        """Main batching loop."""
        while self._running:
            batch = self._collect_batch()
            if batch:
                self._process_batch(batch)
            else:
                time.sleep(0.001)  # brief sleep to avoid busy-waiting

    def _collect_batch(self) -> list[InferenceRequest]:
        """Collect one batch, bounded by batch size, token budget, and wait time."""
        batch: list[InferenceRequest] = []
        total_tokens = 0
        deadline = time.time() + self.config.max_wait_ms / 1000
        while (
            len(batch) < self.config.max_batch_size
            and total_tokens < self.config.max_total_tokens
        ):
            # Poll queues from highest to lowest priority
            # (higher index = higher priority, matching submit()).
            request = None
            for q in reversed(self._queues):
                try:
                    request = q.get_nowait()
                    break
                except Empty:
                    continue
            if request is None:
                if time.time() >= deadline:
                    break  # timed out: flush whatever we have
                time.sleep(0.001)
                continue
            batch.append(request)
            total_tokens += request.max_tokens
        return batch

    def _process_batch(self, batch: list[InferenceRequest]):
        """Run inference on a batch and signal the waiting callers."""
        prompts = [r.prompt for r in batch]
        if self.inference_fn:
            results = self.inference_fn(prompts)
        else:
            results = [f"[batch-response] {p[:50]}..." for p in prompts]
        for req, result in zip(batch, results):
            req.result = result
            req.event.set()
        # Update statistics
        self._stats["batches"] += 1
        self._stats["total_requests"] += len(batch)
        self._stats["avg_batch_size"] = (
            self._stats["total_requests"] / self._stats["batches"]
        )

    def get_stats(self) -> dict:
        queue_sizes = [q.qsize() for q in self._queues]
        return {
            **self._stats,
            "queue_sizes": queue_sizes,
            "total_queued": sum(queue_sizes),
        }
```
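The submit → collect → signal pattern above can be exercised end to end. Below is a stripped-down, single-priority sketch that is runnable on its own; the echo "inference" and the helper name `run_batcher_demo` are illustrative, not part of any library:

```python
import threading
import time
from queue import Queue, Empty


def run_batcher_demo(num_requests=5, max_batch=3, max_wait_s=0.02):
    """Mini single-priority batcher: one worker thread groups queued
    prompts (size + timeout triggers) and echoes them back; per-request
    events signal completion, just like InferenceRequest.event above."""
    q: Queue = Queue()
    results: dict = {}
    done = [threading.Event() for _ in range(num_requests)]
    batch_sizes = []

    def worker():
        while len(results) < num_requests:
            batch = []
            deadline = time.time() + max_wait_s
            # Collect until the batch is full or the wait window closes.
            while len(batch) < max_batch and time.time() < deadline:
                try:
                    batch.append(q.get(timeout=max(0.001, deadline - time.time())))
                except Empty:
                    break
            if not batch:
                continue
            batch_sizes.append(len(batch))
            for idx, prompt in batch:      # the "inference": echo the prompt
                results[idx] = f"echo: {prompt}"
                done[idx].set()

    threading.Thread(target=worker, daemon=True).start()
    for i in range(num_requests):
        q.put((i, f"prompt-{i}"))
    for ev in done:
        ev.wait(timeout=5)
    return results, batch_sizes
```

With the defaults, five requests arrive almost at once, so the worker typically serves them as a batch of 3 followed by a batch of 2 rather than five single-request calls.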
## Priority Queue Design

```python
"""
Multi-priority request queue.
"""
from dataclasses import dataclass, field
from typing import Any
import heapq
import time


@dataclass(order=True)
class PrioritizedRequest:
    """A request tagged with a priority; lower values pop first."""
    priority: int      # negative = higher priority
    timestamp: float   # enqueue time (keeps FIFO order within a class)
    request: Any = field(compare=False)


class PriorityRequestQueue:
    """Priority request queue backed by a binary heap."""

    # Priority mapping
    PRIORITY_MAP = {
        "realtime": -3,  # real-time (e.g. paying users)
        "normal": -2,    # regular traffic
        "batch": -1,     # bulk jobs (lowest priority)
    }

    def __init__(self, max_size: int = 10000):
        self._heap: list[PrioritizedRequest] = []
        self.max_size = max_size
        self._dropped = 0

    def enqueue(self, request: Any, priority: str = "normal") -> bool:
        """Enqueue a request; returns False if the queue is full."""
        if len(self._heap) >= self.max_size:
            self._dropped += 1
            return False
        pri = self.PRIORITY_MAP.get(priority, -2)
        item = PrioritizedRequest(
            priority=pri,
            timestamp=time.time(),
            request=request,
        )
        heapq.heappush(self._heap, item)
        return True

    def dequeue(self) -> Any | None:
        """Pop the highest-priority request, or None if empty."""
        if self._heap:
            return heapq.heappop(self._heap).request
        return None

    def dequeue_batch(self, batch_size: int) -> list:
        """Pop up to batch_size requests in priority order."""
        batch = []
        for _ in range(min(batch_size, len(self._heap))):
            batch.append(heapq.heappop(self._heap).request)
        return batch

    @property
    def size(self) -> int:
        return len(self._heap)

    def stats(self) -> dict:
        return {
            "queue_size": self.size,
            "dropped": self._dropped,
            "max_size": self.max_size,
        }
```
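The ordering that `@dataclass(order=True)` gives `PrioritizedRequest` is just lexicographic comparison on `(priority, timestamp)`. The same behavior can be seen with plain tuples, using a counter in place of `time.time()` as the tiebreaker:

```python
import heapq
import itertools

counter = itertools.count()  # monotonic tiebreaker, standing in for time.time()
heap: list = []
for name, pri in [("job-a", -2), ("job-b", -3), ("job-c", -2)]:
    heapq.heappush(heap, (pri, next(counter), name))

order = [heapq.heappop(heap)[2] for _ in range(3)]
# order == ["job-b", "job-a", "job-c"]:
# realtime (-3) pops first; equal priorities pop in arrival (FIFO) order
```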
## Tuning Batch Parameters

```mermaid
graph LR
    A[Tuning goal] --> B{Latency first?}
    B -->|Yes| C[Small batch + short timeout<br/>batch=8, wait=10ms]
    B -->|No| D{Throughput first?}
    D -->|Yes| E[Large batch + long timeout<br/>batch=64, wait=100ms]
    D -->|No| F[Balanced<br/>batch=32, wait=50ms]
    style A fill:#e3f2fd,stroke:#1976d2,stroke-width:3px
```
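A rough way to choose between these modes is to estimate the batch size a given wait window will actually collect. Assuming roughly uniform (Poisson-like) arrivals, about `rate × wait` requests accumulate per window; this is a back-of-envelope model of my own, not a measured law, and `expected_batch_size` is a hypothetical helper:

```python
def expected_batch_size(arrivals_per_s: float, max_wait_ms: float,
                        max_batch_size: int) -> float:
    """Estimate how many requests one wait window collects:
    roughly rate * wait, capped by max_batch_size, and at least
    the single request that opened the window."""
    return min(float(max_batch_size),
               max(1.0, arrivals_per_s * max_wait_ms / 1000))
```

At 100 req/s a 50 ms window collects about 5 requests, so raising `max_wait_ms` to 100 ms only helps throughput if traffic is heavy enough to fill the larger window; otherwise it just adds latency.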
## Summary

| Topic | Key points |
|---|---|
| Continuous batching | The best-performing strategy; built into frameworks such as vLLM |
| Dynamic batching | Dual trigger: batch size + timeout |
| Priority queue | Paying users / real-time requests are served first |
| Tuning direction | Latency-first → small batches; throughput-first → large batches |

Next chapter: Monitoring and Observability