导航菜单

递归与回溯

递归和回溯是解决复杂问题的重要方法。递归通过将问题分解为更小的子问题来解决,回溯则通过尝试所有可能的解来找到答案。掌握这两种方法对于解决许多算法问题至关重要。

递归基础

递归的三个要素

  1. 递归终止条件:避免无限递归
  2. 递归关系:如何将问题分解为子问题
  3. 递归调用:调用自身解决子问题

递归示例

1. 阶乘

def factorial(n):
    # 终止条件
    if n <= 1:
        return 1
    # 递归关系:n! = n * (n-1)!
    return n * factorial(n - 1)

2. 斐波那契数列

def fibonacci(n):
    # 终止条件
    if n <= 1:
        return n
    # 递归关系:F(n) = F(n-1) + F(n-2)
    return fibonacci(n - 1) + fibonacci(n - 2)

3. 二叉树遍历

def inorder_traversal(root):
    if not root:
        return []
    return inorder_traversal(root.left) + [root.val] + inorder_traversal(root.right)

递归优化

1. 记忆化(Memoization)

from functools import lru_cache

@lru_cache(maxsize=None)
def fibonacci_memo(n):
    if n <= 1:
        return n
    return fibonacci_memo(n - 1) + fibonacci_memo(n - 2)

# 手动记忆化
def fibonacci_manual(n, memo={}):
    if n in memo:
        return memo[n]
    if n <= 1:
        return n
    memo[n] = fibonacci_manual(n - 1, memo) + fibonacci_manual(n - 2, memo)
    return memo[n]

2. 尾递归优化

def factorial_tail(n, acc=1):
    if n <= 1:
        return acc
    return factorial_tail(n - 1, n * acc)

回溯算法

回溯是一种通过尝试所有可能的解来找到答案的算法。当发现当前路径不可能得到解时,会”回溯”到上一步,尝试其他路径。

回溯算法的模板

def backtrack(path, choices):
    # 终止条件
    if is_solution(path):
        result.append(path[:])  # 注意:需要复制
        return
    
    # 遍历所有选择
    for choice in choices:
        # 做选择
        if is_valid(choice):
            path.append(choice)
            # 递归
            backtrack(path, get_next_choices(choice))
            # 撤销选择(回溯)
            path.pop()

经典回溯问题

1. 全排列

def permute(nums):
    result = []
    
    def backtrack(path, remaining):
        if len(path) == len(nums):
            result.append(path[:])
            return
        
        for i in range(len(remaining)):
            path.append(remaining[i])
            backtrack(path, remaining[:i] + remaining[i+1:])
            path.pop()
    
    backtrack([], nums)
    return result

2. 组合

def combine(n, k):
    result = []
    
    def backtrack(path, start):
        if len(path) == k:
            result.append(path[:])
            return
        
        for i in range(start, n + 1):
            path.append(i)
            backtrack(path, i + 1)
            path.pop()
    
    backtrack([], 1)
    return result

3. N 皇后问题

def solve_n_queens(n):
    result = []
    board = ['.' * n for _ in range(n)]
    
    def is_valid(row, col):
        # 检查列
        for i in range(row):
            if board[i][col] == 'Q':
                return False
        
        # 检查左上对角线
        i, j = row - 1, col - 1
        while i >= 0 and j >= 0:
            if board[i][j] == 'Q':
                return False
            i -= 1
            j -= 1
        
        # 检查右上对角线
        i, j = row - 1, col + 1
        while i >= 0 and j < n:
            if board[i][j] == 'Q':
                return False
            i -= 1
            j += 1
        
        return True
    
    def backtrack(row):
        if row == n:
            result.append(board[:])
            return
        
        for col in range(n):
            if is_valid(row, col):
                board[row] = board[row][:col] + 'Q' + board[row][col+1:]
                backtrack(row + 1)
                board[row] = board[row][:col] + '.' + board[row][col+1:]
    
    backtrack(0)
    return result

4. 单词搜索

def exist(board, word):
    rows, cols = len(board), len(board[0])
    
    def backtrack(row, col, index):
        if index == len(word):
            return True
        
        if row < 0 or row >= rows or col < 0 or col >= cols:
            return False
        
        if board[row][col] != word[index]:
            return False
        
        # 标记为已访问
        temp = board[row][col]
        board[row][col] = '#'
        
        # 四个方向
        directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
        for dr, dc in directions:
            if backtrack(row + dr, col + dc, index + 1):
                return True
        
        # 回溯
        board[row][col] = temp
        return False
    
    for i in range(rows):
        for j in range(cols):
            if backtrack(i, j, 0):
                return True
    
    return False

5. 子集

def subsets(nums):
    result = []
    
    def backtrack(path, start):
        result.append(path[:])
        
        for i in range(start, len(nums)):
            path.append(nums[i])
            backtrack(path, i + 1)
            path.pop()
    
    backtrack([], 0)
    return result

回溯优化技巧

1. 剪枝

def combination_sum(candidates, target):
    result = []
    candidates.sort()  # 排序以便剪枝
    
    def backtrack(path, start, remaining):
        if remaining == 0:
            result.append(path[:])
            return
        
        for i in range(start, len(candidates)):
            # 剪枝:如果当前数字已经大于剩余值,后面的数字更大,不可能满足
            if candidates[i] > remaining:
                break
            
            path.append(candidates[i])
            backtrack(path, i, remaining - candidates[i])
            path.pop()
    
    backtrack([], 0, target)
    return result

2. 使用集合去重

def permute_unique(nums):
    result = []
    nums.sort()
    
    def backtrack(path, used):
        if len(path) == len(nums):
            result.append(path[:])
            return
        
        for i in range(len(nums)):
            if used[i]:
                continue
            # 剪枝:如果当前数字与前一个相同,且前一个未使用,跳过
            if i > 0 and nums[i] == nums[i-1] and not used[i-1]:
                continue
            
            used[i] = True
            path.append(nums[i])
            backtrack(path, used)
            path.pop()
            used[i] = False
    
    backtrack([], [False] * len(nums))
    return result

常见题目

1. 电话号码的字母组合

def letter_combinations(digits):
    if not digits:
        return []
    
    mapping = {
        '2': 'abc', '3': 'def', '4': 'ghi', '5': 'jkl',
        '6': 'mno', '7': 'pqrs', '8': 'tuv', '9': 'wxyz'
    }
    
    result = []
    
    def backtrack(path, index):
        if index == len(digits):
            result.append(''.join(path))
            return
        
        for char in mapping[digits[index]]:
            path.append(char)
            backtrack(path, index + 1)
            path.pop()
    
    backtrack([], 0)
    return result

2. 括号生成

def generate_parenthesis(n):
    result = []
    
    def backtrack(path, open_count, close_count):
        if len(path) == 2 * n:
            result.append(''.join(path))
            return
        
        if open_count < n:
            path.append('(')
            backtrack(path, open_count + 1, close_count)
            path.pop()
        
        if close_count < open_count:
            path.append(')')
            backtrack(path, open_count, close_count + 1)
            path.pop()
    
    backtrack([], 0, 0)
    return result

总结

递归和回溯是解决复杂问题的重要方法:

  1. 递归:将问题分解为子问题,需要明确的终止条件和递归关系
  2. 回溯:尝试所有可能的解,通过剪枝优化效率
  3. 优化技巧:记忆化、剪枝、去重
  4. 常见问题:排列、组合、N 皇后、单词搜索、子集

掌握递归和回溯能够帮助你解决许多复杂的算法问题。


接下来,让我们学习动态规划,这是解决最优化问题的重要方法。

搜索