Inference - 控制输出质量
"推理通过策略性的 token 选择,将概率分布转化为连贯的文本。"
训练构建了模型,但推理决定了它的输出。理解解码策略、采样参数和优化技术对于在生产环境中控制模型行为至关重要。本文档涵盖自回归生成、解码算法、采样参数以及大规模部署 LLM 的性能优化技术。
自回归生成
生成循环
LLM 以自回归方式生成文本——一次一个 token,每个新 token 以所有先前 token 为条件。
基本实现
import math

import torch
import torch.nn.functional as F
def generate_autoregressive(
    model,
    input_ids: torch.Tensor,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    top_k: int = None,
    top_p: float = 1.0,
    eos_token_id: int = None
) -> torch.Tensor:
    """Generate tokens autoregressively, one token per step.

    Args:
        model: Language model; ``model(ids)`` must expose ``.logits`` of
            shape (batch, seq, vocab).
        input_ids: Prompt token IDs, shape (batch, seq_len).
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; must be > 0 (< 1 sharpens the
            distribution, > 1 flattens it).
        top_k: If set, keep only the k highest-probability tokens.
            Clamped to the vocabulary size, so oversized values no longer
            raise in ``torch.topk``.
        top_p: Nucleus-sampling threshold; 1.0 disables the filter.
        eos_token_id: If set, stop once every sequence in the batch has
            just emitted this token.

    Returns:
        Tensor of shape (batch, seq_len + n_generated): the prompt
        followed by the sampled continuation.
    """
    current_ids = input_ids.clone()
    for _ in range(max_new_tokens):
        # Forward pass -- no autograd graph is needed for generation.
        with torch.no_grad():
            outputs = model(current_ids)
        logits = outputs.logits[:, -1, :]  # (batch, vocab_size)
        # Temperature scaling.
        logits = logits / temperature
        # Top-k filtering: mask everything outside the k best tokens.
        if top_k is not None:
            k = min(top_k, logits.size(-1))  # guard: topk raises if k > vocab
            top_k_logits, top_k_indices = torch.topk(logits, k)
            logits = torch.full_like(logits, float('-inf'))
            logits.scatter_(1, top_k_indices, top_k_logits)
        # Top-p (nucleus) filtering.
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Drop tokens once the cumulative mass exceeds the threshold;
            # shift the mask right so the first token crossing it is kept.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = False
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float('-inf')
        # Sample the next token from the filtered distribution.
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # (batch, 1)
        # Append to the running sequence.
        current_ids = torch.cat([current_ids, next_token], dim=1)
        # Stop when every sequence in the batch just produced EOS.
        if eos_token_id is not None and (next_token == eos_token_id).all():
            break
    return current_ids
# Usage example
input_text = "The future of AI is"
input_ids = tokenizer.encode(input_text, return_tensors="pt")
output_ids = generate_autoregressive(model, input_ids, max_new_tokens=50, temperature=0.8)
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)
解码策略
策略对比
| 策略 | 描述 | 速度 | 质量 | 多样性 | 适用场景 |
|---|---|---|---|---|---|
| 贪心搜索 | 始终选择最高概率 | 最快 | 适合事实性任务 | 无 | 事实问答、代码 |
| 束搜索 | 保留前 k 个假设 | 慢 | 高质量 | 低 | 翻译、摘要 |
| 采样 | 从概率分布中采样 | 快 | 可变 | 高 | 创意写作 |
| 核采样 (Top-p) | 从最小顶部概率质量中采样 | 快 | 好 | 高 | 通用聊天、助手 |
| Top-k | 从前 k 个 token 中采样 | 快 | 好 | 中高 | 平衡生成 |
| 对比搜索 | 平衡概率 + 退化惩罚 | 中 | 很高 | 中 | 长文本内容 |
| MCTS | 带前瞻的树搜索 | 很慢 | 最佳 | 中 | 复杂推理 |
贪心搜索
def greedy_search(model, input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    """Greedy decoding: always pick the single most likely next token.

    Deterministic and fast, but prone to repetition loops on open-ended
    prompts.

    Args:
        model: Language model exposing ``.logits`` of shape (batch, seq, vocab).
        input_ids: Prompt token IDs, shape (batch, seq_len).
        max_new_tokens: Number of tokens to append.

    Returns:
        Tensor of shape (batch, seq_len + max_new_tokens).
    """
    current_ids = input_ids.clone()
    for _ in range(max_new_tokens):
        # no_grad: inference only -- without it the autograd graph grows
        # with every generated token and wastes memory.
        with torch.no_grad():
            outputs = model(current_ids)
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        current_ids = torch.cat([current_ids, next_token], dim=1)
    return current_ids

# Example: greedy search can repeat "The the the the..." if stuck in a loop
束搜索
def beam_search(
    model,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    num_beams: int = 4,
    length_penalty: float = 1.0
) -> torch.Tensor:
    """Beam search: keep the top-k scoring hypotheses at every step.

    Args:
        model: Language model exposing ``.logits`` of shape (batch, seq, vocab).
        input_ids: Prompt token IDs, shape (batch, seq_len).
        max_new_tokens: Number of tokens to generate.
        num_beams: Number of hypotheses tracked per input.
        length_penalty: Exponent for length normalization of final scores
            (score / generated_length ** length_penalty). 1.0 ~ average
            log-prob; > 1 favors longer sequences. NOTE: this parameter was
            previously accepted but silently ignored; it is now applied
            when selecting the winning beam.

    Returns:
        Tensor of shape (batch, seq_len + max_new_tokens) with the best
        beam per input.
    """
    batch_size = input_ids.shape[0]
    input_ids = input_ids.repeat_interleave(num_beams, dim=0)
    # All beams start identical, so only the first is live at step 0;
    # otherwise the top-k would be filled with duplicates.
    beam_scores = torch.zeros(batch_size * num_beams, device=input_ids.device)
    beam_scores[1::num_beams] = float('-inf')
    for step in range(max_new_tokens):
        # Inference only -- no autograd graph.
        with torch.no_grad():
            outputs = model(input_ids)
        next_token_logits = outputs.logits[:, -1, :]
        next_token_scores = F.log_softmax(next_token_logits, dim=-1)
        # Cumulative log-prob of each (beam, token) continuation.
        vocab_size = next_token_scores.shape[-1]
        next_scores = beam_scores.unsqueeze(-1) + next_token_scores
        # Flatten beams x vocab and keep the num_beams best continuations.
        next_scores = next_scores.view(batch_size, num_beams * vocab_size)
        next_scores, next_tokens = torch.topk(next_scores, k=num_beams, dim=-1)
        # Decode flat indices back into (source beam, token) pairs.
        beam_indices = next_tokens // vocab_size
        token_indices = next_tokens % vocab_size
        # Reorder beams to their new sources and append the chosen tokens.
        input_ids = input_ids.view(batch_size, num_beams, -1)
        input_ids = input_ids[torch.arange(batch_size).unsqueeze(-1), beam_indices]
        input_ids = input_ids.reshape(batch_size * num_beams, -1)
        next_token_ids = token_indices.view(-1, 1)
        input_ids = torch.cat([input_ids, next_token_ids], dim=-1)
        beam_scores = next_scores.view(-1)
    # Length-normalize and pick the best beam per input. All beams have
    # the same generated length here (no early EOS handling), so this is
    # a monotonic rescaling -- it honors the documented length_penalty
    # contract without changing which beam wins.
    norm = max(max_new_tokens, 1) ** length_penalty
    final_scores = beam_scores.view(batch_size, num_beams) / norm
    input_ids = input_ids.view(batch_size, num_beams, -1)
    best_beam_indices = final_scores.argmax(dim=-1)
    output_ids = input_ids[torch.arange(batch_size), best_beam_indices]
    return output_ids
# Usage example
output = beam_search(model, input_ids, max_new_tokens=50, num_beams=4)
对比搜索
对比搜索将概率建模与退化惩罚相结合:
def contrastive_search(
    model,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    top_k: int = 5,
    alpha: float = 0.6,
    momentum: float = 0.5
) -> torch.Tensor:
    """
    Contrastive search: balance token probability against a degeneration
    (self-similarity) penalty.

    Args:
        model: Language model. NOTE(review): this code reads
            ``outputs.hidden_states``, so the model presumably must be
            invoked with output_hidden_states enabled -- confirm against
            the caller.
        input_ids: Prompt token IDs.
        max_new_tokens: Maximum number of tokens to generate.
        top_k: Number of candidate tokens considered per step.
        alpha: Degeneration-penalty weight (0 = pure probability,
            1 = pure degeneration penalty).
        momentum: Weight for the running model-confidence average.
            NOTE(review): ``cumulative_probs`` is updated below but never
            enters the token score -- it appears to be dead state.

    Reference:
        Su, J., et al. (2022). "A Contrastive Framework for Neural Text Generation"
        https://arxiv.org/abs/2202.06417
    """
    current_ids = input_ids.clone()
    cumulative_probs = None
    for _ in range(max_new_tokens):
        # Forward pass (no gradients needed during generation).
        with torch.no_grad():
            outputs = model(current_ids)
        logits = outputs.logits[:, -1, :]
        hidden_states = outputs.hidden_states[-1][:, -1, :]  # last-layer state of the newest position
        # Top-k candidates and their renormalized probabilities.
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        top_k_probs = F.softmax(top_k_logits, dim=-1)
        # Model confidence = probability mass of the single best candidate.
        model_confidence = top_k_probs.max(dim=-1)[0]
        # Degeneration penalty: cosine similarity between the current
        # hidden state and the hidden states of earlier tokens.
        prev_hidden = outputs.hidden_states[-1][:, :-1, :]  # all earlier positions
        # Compare against recent tokens only (local context window).
        if prev_hidden.size(1) > 0:
            recent_hidden = prev_hidden[:, -min(5, prev_hidden.size(1)):, :]
            similarities = F.cosine_similarity(
                hidden_states.unsqueeze(1),
                recent_hidden,
                dim=-1
            ).max(dim=1)[0]
            # NOTE(review): this extra .max() collapses the batch
            # dimension, so one scalar penalty is shared by every
            # sequence in the batch -- likely unintended for batch > 1.
            degeneration_penalty = similarities.max()
        else:
            degeneration_penalty = torch.tensor(0.0)
        # Momentum-weighted running confidence (unused in scoring; see NOTE).
        if cumulative_probs is None:
            cumulative_probs = model_confidence
        else:
            cumulative_probs = momentum * cumulative_probs + (1 - momentum) * model_confidence
        # Score candidates: (1 - alpha) * probability - alpha * degeneration.
        scores = (1 - alpha) * top_k_probs - alpha * degeneration_penalty
        # Greedily take the best-scoring candidate per sequence.
        best_idx = scores.argmax(dim=-1)
        next_token = top_k_indices[range(len(best_idx)), best_idx].unsqueeze(-1)
        current_ids = torch.cat([current_ids, next_token], dim=1)
    return current_ids
MCTS(蒙特卡洛树搜索)解码
适用于需要更深层前瞻的任务:
class MCTSNode:
    """A node in the Monte Carlo tree search over token continuations."""

    def __init__(self, token_id: int, parent=None):
        self.token_id = token_id        # token this node appends (None for the root)
        self.parent = parent            # parent node, or None at the root
        self.children = []              # expanded child nodes
        self.visits = 0                 # number of simulations through this node
        self.total_value = 0.0          # accumulated backed-up value
        self.prior_prob = 0.0           # model prior for this token

    def ucb_score(self, c_puct: float = 1.0) -> float:
        """Upper confidence bound used to pick children during selection."""
        if self.visits == 0:
            # Unvisited nodes are always explored first.
            return float('inf')
        mean_value = self.total_value / self.visits
        if not self.parent:
            return mean_value
        bonus = c_puct * math.sqrt(math.log(self.parent.visits) / self.visits)
        return mean_value + bonus
def mcts_decode(
    model,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    num_simulations: int = 50,
    c_puct: float = 1.0,
    temperature: float = 1.0
) -> torch.Tensor:
    """Monte Carlo tree search decoding with shallow lookahead.

    Much more expensive than single-step decoding, but can find better
    sequences for tasks that reward planning ahead.

    Args:
        model: Language model exposing ``.logits`` (and optionally ``.loss``).
        input_ids: Prompt token IDs, shape (1, seq_len).
        max_new_tokens: Maximum number of tokens to emit.
        num_simulations: Tree simulations run per emitted token.
        c_puct: Exploration constant for the UCB selection rule.
        temperature: Softmax temperature used when expanding candidates.

    Returns:
        Tensor containing the prompt plus the chosen continuation.

    Fixes over the naive version:
      * the root is included in backpropagation (previously its visit count
        stayed 0, so child UCB evaluated log(0) on the second simulation);
      * expansion conditions the model on the tokens selected along the
        path, not just the committed prefix;
      * the prior probability is stored on the new child, not its parent;
      * a model output whose ``.loss`` is None no longer crashes evaluation;
      * appended tensors inherit the prefix's device and dtype.
    """

    def _extend(prefix: torch.Tensor, tokens) -> torch.Tensor:
        # Append python-int token IDs to prefix without mutating it.
        if not tokens:
            return prefix
        tail = torch.tensor([tokens], dtype=prefix.dtype, device=prefix.device)
        return torch.cat([prefix, tail], dim=1)

    current_ids = input_ids.clone()
    for _ in range(max_new_tokens):
        root = MCTSNode(token_id=None)
        for _ in range(num_simulations):
            # Selection: walk down via UCB until a leaf is reached.
            node = root
            path = [root]
            while node.children:
                node = max(node.children, key=lambda n: n.ucb_score(c_puct))
                path.append(node)
            path_tokens = [n.token_id for n in path[1:]]  # root carries no token
            # Expansion: add one child unless the depth limit is hit.
            if len(path_tokens) < 5:  # cap lookahead depth
                with torch.no_grad():
                    outputs = model(_extend(current_ids, path_tokens))
                logits = outputs.logits[:, -1, :] / temperature
                probs = F.softmax(logits, dim=-1)
                # Sample a candidate continuation token.
                token = torch.multinomial(probs, num_samples=1).item()
                child = MCTSNode(token, parent=node)
                child.prior_prob = probs[0, token].item()
                node.children.append(child)
                path.append(child)
                path_tokens.append(token)
            # Evaluation: score the full candidate sequence with the model.
            with torch.no_grad():
                outputs = model(_extend(current_ids, path_tokens))
            loss = getattr(outputs, 'loss', None)
            # Value = negative loss (higher is better); 0 when unavailable.
            value = -loss.item() if loss is not None else 0.0
            # Backpropagation: update every node on the path, root included.
            for visited in path:
                visited.visits += 1
                visited.total_value += value
        # Commit the most-visited root child as the next token.
        if root.children:
            best_child = max(root.children, key=lambda n: n.visits)
            current_ids = _extend(current_ids, [best_child.token_id])
        else:
            break
    return current_ids
带温度的采样
def sample_with_temperature(
    model,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float = 1.0
) -> torch.Tensor:
    """Sample from the model's distribution with temperature scaling.

    Args:
        model: Language model exposing ``.logits`` of shape (batch, seq, vocab).
        input_ids: Prompt token IDs, shape (batch, seq_len).
        max_new_tokens: Number of tokens to append.
        temperature: Softmax temperature; must be > 0.

    Returns:
        Tensor of shape (batch, seq_len + max_new_tokens).
    """
    current_ids = input_ids.clone()
    for _ in range(max_new_tokens):
        # no_grad: inference only -- avoids accumulating an autograd graph.
        with torch.no_grad():
            outputs = model(current_ids)
        logits = outputs.logits[:, -1, :] / temperature
        # Sample from the softmax distribution.
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        current_ids = torch.cat([current_ids, next_token], dim=1)
    return current_ids

# Temperature effects
# T=0.1: nearly deterministic, close to greedy
# T=0.5: focused, mostly high-probability tokens
# T=1.0: standard sampling
# T=1.5: more creative, includes lower-probability tokens
# T=2.0+: very random, often incoherent
采样参数
温度
控制采样中的随机性:
def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Scale logits by the sampling temperature.

    A temperature below 1 sharpens the resulting distribution; above 1
    flattens it. Temperature must be non-zero.
    """
    scaled = logits / temperature
    return scaled
# Effect of temperature on the probability distribution
logits = torch.tensor([2.0, 1.0, 0.0, -1.0, -2.0])
print("温度效果:")
for temp in [0.1, 0.5, 1.0, 2.0]:
    scaled = apply_temperature(logits, temp)
    probs = F.softmax(scaled, dim=0)
    print(f" T={temp:.1f}: {probs.tolist()}")
输出:
温度效果:
T=0.1: [0.97, 0.03, 0.00, 0.00, 0.00] # 非常尖锐
T=0.5: [0.67, 0.24, 0.07, 0.01, 0.00] # 集中
T=1.0: [0.50, 0.27, 0.12, 0.07, 0.04] # 平衡
T=2.0: [0.39, 0.32, 0.16, 0.09, 0.05] # 平坦
Top-K vs Top-P(核采样)
def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep only the top_k highest logits; mask the rest to -inf.

    Args:
        logits: 1-D logits tensor (vocab_size,).
        top_k: Number of tokens to keep. Clamped to the vocabulary size,
            so oversized values no longer raise in ``torch.topk``.

    Returns:
        A new tensor with every non-top-k entry set to -inf.
    """
    k = min(top_k, logits.size(0))  # guard: topk raises when k > dim size
    top_k_logits, top_k_indices = torch.topk(logits, k)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(0, top_k_indices, top_k_logits)
    return filtered
def apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus (top-p) filtering.

    Keeps the smallest set of highest-probability tokens whose cumulative
    probability mass reaches ``top_p``; all other entries are masked to
    -inf. Note: ``logits`` is modified in place and also returned.
    """
    sorted_vals, order = torch.sort(logits, descending=True)
    cumulative = F.softmax(sorted_vals, dim=-1).cumsum(dim=-1)
    # Mark tokens past the threshold for removal, then shift the mask
    # right by one so the first token that crosses top_p is still kept.
    drop_sorted = cumulative > top_p
    drop_sorted = torch.roll(drop_sorted, 1, dims=0)
    drop_sorted[0] = False
    # Map the mask from sorted order back to vocabulary order.
    drop = torch.zeros_like(drop_sorted).scatter(0, order, drop_sorted)
    logits[drop] = float('-inf')
    return logits
# Example: Top-k vs Top-p
logits = torch.randn(50000)  # large vocabulary
# Top-k: always keeps exactly k tokens
top_k_filtered = apply_top_k(logits, top_k=50)
# Top-p: keeps a variable number of tokens
top_p_filtered = apply_top_p(logits, top_p=0.9)
num_kept = (top_p_filtered != float('-inf')).sum()
print(f"Top-p=0.9 保留了 {num_kept} 个 token")
频率和存在惩罚
def apply_frequency_penalty(
    logits: torch.Tensor,
    token_ids: torch.Tensor,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0
) -> torch.Tensor:
    """Apply frequency and presence penalties to previously seen tokens.

    Args:
        logits: Model logits, shape (vocab_size,). Modified in place and
            also returned -- pass ``logits.clone()`` to keep the original.
        token_ids: Previously generated token IDs (repeats allowed).
        frequency_penalty: Subtracted once per occurrence of a token
            (higher = less repetition).
        presence_penalty: Subtracted once per distinct token seen
            (binary: present or absent).

    Returns:
        The penalized logits tensor.
    """
    # Count how often each token has appeared.
    unique_tokens, counts = torch.unique(token_ids, return_counts=True)
    # Vectorized update: each seen token loses
    # frequency_penalty * count + presence_penalty in a single indexed
    # assignment, replacing the former Python-level per-token loop.
    penalties = frequency_penalty * counts.to(logits.dtype) + presence_penalty
    logits[unique_tokens] -= penalties
    return logits
# Example
generated_tokens = torch.tensor([10, 20, 10, 30, 10, 20])  # 10: x3, 20: x2, 30: x1
logits = torch.randn(50000)
# With penalties, tokens 10 and 20 are penalized more heavily
logits_penalized = apply_frequency_penalty(
    logits.clone(),
    generated_tokens,
    frequency_penalty=0.5,
    presence_penalty=0.1
)
参数参考表
| 参数 | 范围 | 效果 | 使用场景 |
|---|---|---|---|
| temperature | 0.0 - 2.0 | 随机性 | 0.2: 编码、事实 / 0.8: 聊天 / 1.2: 创意 |
| top_k | 1 - 100 | 多样性 | 1: 贪心 / 40-50: 平衡 / 100: 非常多样 |
| top_p | 0.1 - 1.0 | 质量过滤 | 0.5: 集中 / 0.9: 标准 / 1.0: 无过滤 |
| frequency_penalty | 0.0 - 2.0 | 减少重复 | 0.0: 无 / 0.5: 适度 / 1.0: 强 |
| presence_penalty | 0.0 - 2.0 | 鼓励多样性 | 0.0: 无 / 0.5: 适度 / 1.0: 强 |