回溯算法
为什么回溯算法很重要
回溯算法通过递增地构建解,并放弃那 些无法导向有效解的局部解,从而系统地探索所有可能性:
- 组合问题:生成所有的排列或组合。
- 约束满足问题:数独、N 皇后问题、填字游戏。
- 路径搜索:迷宫求解、图遍历。
- 最优化问题:在众多可能性中寻找最优解。
实际影响:
- 使用回溯法解决 N 皇后问题:N=8 时仅需 2 秒。
- 暴力检查所有位置:N=8 时约需 4 小时。
- 速度提升 7200 倍。
核心概念
回溯模板
void backtrack(state, parameters) {
if (isSolution(state)) {
recordSolution(state);
return;
}
for (choice in generateChoices(state)) {
if (isValid(choice)) {
makeChoice(choice);
backtrack(state, parameters);
undoChoice(choice); // 回溯(撤销选择)
}
}
}
回溯 vs 递归
| 维度 | 回溯 | 递归 |
|---|---|---|
| 搜索空间 | 系统化探索所有可能 | 分而治之 |
| 剪枝 | 激进(提前终止) | 通常不包含剪枝 |
| 用例 | 约束满足问题 | 问题拆解/子问题求解 |
| 状态管理 | 需要手动做出/撤销选择 | 通常不需要状态回滚 |
深入理解
全排列 (Permutations)
生成所有可能的排列方案:
public List<List<Integer>> permute(int[] nums) {
List<List<Integer>> result = new ArrayList<>();
backtrack(result, new ArrayList<>(), new boolean[nums.length], nums);
return result;
}
private void backtrack(List<List<Integer>> result, List<Integer> current,
boolean[] used, int[] nums) {
if (current.size() == nums.length) {
result.add(new ArrayList<>(current));
return;
}
for (int i = 0; i < nums.length; i++) {
if (used[i]) continue;
current.add(nums[i]);
used[i] = true;
backtrack(result, current, used, nums);
current.remove(current.size() - 1); // 回溯
used[i] = false;
}
}
组合 (Combinations)
生成所有包含 k 个元素的组合:
public List<List<Integer>> combine(int n, int k) {
List<List<Integer>> result = new ArrayList<>();
backtrack(result, new ArrayList<>(), 1, n, k);
return result;
}
private void backtrack(List<List<Integer>> result, List<Integer> current,
int start, int n, int k) {
if (current.size() == k) {
result.add(new ArrayList<>(current));
return;
}
for (int i = start; i <= n; i++) {
current.add(i);
backtrack(result, current, i + 1, n, k); // 注意:是 i + 1 而不是 start
current.remove(current.size() - 1); // 回溯
}
}
剪枝策略
在排列中去除重复项
public List<List<Integer>> permuteUnique(int[] nums) {
List<List<Integer>> result = new ArrayList<>();
Arrays.sort(nums); // 排序以便将重复项聚集在一起
backtrack(result, new ArrayList<>(), new boolean[nums.length], nums);
return result;
}
private void backtrack(List<List<Integer>> result, List<Integer> current,
boolean[] used, int[] nums) {
if (current.size() == nums.length) {
result.add(new ArrayList<>(current));
return;
}
for (int i = 0; i < nums.length; i++) {
if (used[i]) continue;
// 跳过重复:仅使用重复数字中的第一个出现项
if (i > 0 && nums[i] == nums[i - 1] && !used[i - 1]) continue;
current.add(nums[i]);
used[i] = true;
backtrack(result, current, used, nums);
current.remove(current.size() - 1);
used[i] = false;
}
}
常见陷阱
❌ 添加结果时未进行深拷贝
if (current.size() == k) {
result.add(current); // 错误:添加的是引用!
}