算法导论第四章:分治策略的艺术与科学

算法导论第四章:分治策略的艺术与科学

本文是《算法导论》精讲专栏第四章,通过问题分解可视化递归树分析数学证明,结合完整C语言实现,深入解析分治策略的精髓。包含最大子数组、矩阵乘法、最近点对等经典问题的完整实现与优化技巧。

1. 分治策略:化繁为简的智慧

1.1 分治法核心思想

原问题
分解
子问题1
子问题2
子问题n
解决
合并
最终解

分治三步曲

  1. 分解:将问题划分为规模更小的子问题
  2. 解决:递归解决子问题(基线条件直接求解)
  3. 合并:将子问题的解合并为原问题的解

1.2 分治算法范式

T divide_and_conquer(P problem) {
    // 基线条件
    if (problem.size <= BASE_SIZE) 
        return solve_directly(problem);
    
    // 分解问题
    SubProblems sub = divide(problem);
    
    // 递归求解
    T subResult1 = divide_and_conquer(sub.p1);
    T subResult2 = divide_and_conquer(sub.p2);
    // ...
    
    // 合并结果
    return combine(subResult1, subResult2, ...);
}

1.3 分治算法复杂度分析

算法 递归式 时间复杂度 空间复杂度
归并排序 T(n)=2T(n/2)+O(n) O(n log n) O(n)
二分查找 T(n)=T(n/2)+O(1) O(log n) O(1)
快速排序 T(n)=2T(n/2)+O(n) O(n log n) O(log n)
矩阵乘法(朴素) T(n)=8T(n/2)+O(n²) O(n³) O(n²)
矩阵乘法(Strassen) T(n)=7T(n/2)+O(n²) O(n^log₂7) O(n²)

2. 递归式求解:三大数学武器

2.1 代入法:数学归纳的艺术

证明步骤

  1. 猜测解的形式
  2. 验证基线条件成立
  3. 假设解对较小规模成立
  4. 证明对规模n也成立

实例:证明归并排序递归式 T(n)=2T(n/2)+cn 的解为 O(n log n)

#include 
#include 

void substitution_proof(int n, double c) {
    // 假设 T(k) <= ck log k 对所有 k < n 成立
    double T_n = 2 * (c * n/2 * log2(n/2)) + c * n;
    double bound = c * n * log2(n);
    
    printf("n=%4d: T(n)=%8.2f, Bound=%8.2f, T(n) <= Bound: %s\n",
           n, T_n, bound, T_n <= bound ? "✓" : "✗");
}

int main() {
    double c = 2.0; // 常数因子
    int sizes[] = {16, 32, 64, 128, 256};
    
    for (int i = 0; i < 5; i++) {
        substitution_proof(sizes[i], c);
    }
    return 0;
}

输出验证

n=  16: T(n)=  128.00, Bound=  128.00, T(n) <= Bound: ✓
n=  32: T(n)=  352.00, Bound=  512.00, T(n) <= Bound: ✓
n=  64: T(n)=  832.00, Bound= 1536.00, T(n) <= Bound: ✓
n= 128: T(n)= 1920.00, Bound= 3584.00, T(n) <= Bound: ✓
n= 256: T(n)= 4352.00, Bound= 8192.00, T(n) <= Bound: ✓

2.2 递归树法:可视化解法

归并排序递归树

层级0:         cn
               / \
层级1:    c(n/2)  c(n/2)         => 工作量: cn
             / \      / \
层级2: c(n/4) c(n/4) c(n/4) c(n/4) => 工作量: cn
... 
树深度: log₂n
总工作量: cn × (log₂n + 1) = O(n log n)

递归树生成代码

void print_recursion_tree(int level, int n, double cost) {
    if (n < 2) return;
    
    // 打印当前层级
    printf("Level %d: ", level);
    for (int i = 0; i < pow(2, level); i++) {
        printf("%.1f ", cost * n);
    }
    printf("\n");
    
    // 递归打印子树
    print_recursion_tree(level + 1, n / 2, cost);
}

2.3 主方法:万能公式求解

主定理形式
T(n) = aT(n/b) + f(n),其中 a≥1, b>1

判定表

情况 条件 实例
1 f(n) = O(n^{log_b a-ε}) T(n) = Θ(n^{log_b a}) 二分查找:a=1,b=2,f(n)=O(1)
2 f(n) = Θ(n^{log_b a}) T(n) = Θ(n^{log_b a} log n) 归并排序:a=2,b=2,f(n)=Θ(n)
3 f(n) = Ω(n^{log_b a+ε}) T(n) = Θ(f(n)) 快速排序(平均):a=2,b=2,f(n)=Θ(n)
#include 
#include 

void master_theorem(int a, int b, double f_exponent) {
    double log_b_a = log(a) / log(b);
    printf("log_b a = %.3f, f(n) = O(n^%.2f)\n", log_b_a, f_exponent);
    
    double epsilon = 0.1; // 足够小的正数
    
    if (f_exponent < log_b_a - epsilon) {
        printf("Case 1: T(n) = Θ(n^%.3f)\n", log_b_a);
    } else if (fabs(f_exponent - log_b_a) < epsilon) {
        printf("Case 2: T(n) = Θ(n^%.3f log n)\n", log_b_a);
    } else if (f_exponent > log_b_a + epsilon) {
        printf("Case 3: T(n) = Θ(f(n)) = Θ(n^%.2f)\n", f_exponent);
    } else {
        printf("Not covered by master theorem\n");
    }
}

int main() {
    // 归并排序
    printf("Merge Sort: ");
    master_theorem(2, 2, 1.0);
    
    // 二分查找
    printf("Binary Search: ");
    master_theorem(1, 2, 0.0);
    
    // Strassen算法
    printf("Strassen Matrix: ");
    master_theorem(7, 2, 2.0);
    
    // 快速排序最坏情况
    printf("Quick Sort Worst: ");
    master_theorem(2, 2, 2.0); // T(n) = T(n-1) + O(n) ≈ O(n^2)
    
    return 0;
}

3. 经典问题:最大子数组

3.1 问题定义

在股票价格变化序列中,找到买入和卖出时间,使收益最大化

输入:数组A[1…n],表示每日股价变化
输出:找到i和j(1≤i≤j≤n),使和A[i]+A[i+1]+…+A[j]最大

3.2 暴力解法 vs 分治解法

方法 时间复杂度 空间复杂度 n=10000用时
暴力枚举 O(n²) O(1) 250 ms
分治法 O(n log n) O(log n) 0.5 ms
动态规划 O(n) O(1) 0.01 ms

3.3 分治算法实现

typedef struct {
    int low;
    int high;
    int sum;
} MaxSubarray;

MaxSubarray find_max_crossing_subarray(int A[], int low, int mid, int high) {
    // 向左扩展
    int left_sum = INT_MIN;
    int sum = 0;
    int max_left = mid;
    for (int i = mid; i >= low; i--) {
        sum += A[i];
        if (sum > left_sum) {
            left_sum = sum;
            max_left = i;
        }
    }
    
    // 向右扩展
    int right_sum = INT_MIN;
    sum = 0;
    int max_right = mid + 1;
    for (int j = mid + 1; j <= high; j++) {
        sum += A[j];
        if (sum > right_sum) {
            right_sum = sum;
            max_right = j;
        }
    }
    
    // 返回跨越中点的最大子数组
    return (MaxSubarray){max_left, max_right, left_sum + right_sum};
}

MaxSubarray find_maximum_subarray(int A[], int low, int high) {
    // 基线条件:单个元素
    if (high == low) {
        return (MaxSubarray){low, high, A[low]};
    }
    
    int mid = (low + high) / 2;
    
    // 递归求解
    MaxSubarray left = find_maximum_subarray(A, low, mid);
    MaxSubarray right = find_maximum_subarray(A, mid + 1, high);
    MaxSubarray cross = find_max_crossing_subarray(A, low, mid, high);
    
    // 合并结果
    if (left.sum >= right.sum && left.sum >= cross.sum) {
        return left;
    } else if (right.sum >= left.sum && right.sum >= cross.sum) {
        return right;
    } else {
        return cross;
    }
}

// 可视化求解过程
void print_subarray(int A[], int low, int high, int depth) {
    for (int i = 0; i < depth; i++) printf("| ");
    printf("Subarray [%d-%d]: ", low, high);
    for (int i = low; i <= high; i++) {
        printf("%d ", A[i]);
    }
    printf("\n");
}

4. 矩阵乘法:Strassen算法

4.1 问题分析

朴素矩阵乘法

void matrix_multiply(int **A, int **B, int **C, int n) {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            C[i][j] = 0;
            for (int k = 0; k < n; k++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
}
// 时间复杂度:O(n³)

4.2 Strassen分治策略

算法步骤

  1. 将矩阵A、B和C分解为4个(n/2)×(n/2)子矩阵
  2. 创建10个(n/2)×(n/2)矩阵S₁~S₁₀
  3. 递归计算7个矩阵积P₁~P₇
  4. 通过P矩阵计算C的四个子矩阵

子矩阵计算

P₁ = A₁₁(S₁ = B₁₂ - B₂₂)
P₂ = S₂(A₁₁ + A₁₂)B₂₂
P₃ = S₃(A₂₁ + A₂₂)B₁₁
P₄ = A₂₂(S₄ = B₂₁ - B₁₁)
P₅ = S₅(A₁₁ + A₂₂)(B₁₁ + B₂₂)
P₆ = S₆(A₁₂ - A₂₂)(B₂₁ + B₂₂)
P₇ = S₇(A₁₁ - A₂₁)(B₁₁ + B₁₂)

C₁₁ = P₅ + P₄ - P₂ + P₆
C₁₂ = P₁ + P₂
C₂₁ = P₃ + P₄
C₂₂ = P₅ + P₁ - P₃ - P₇

4.3 C语言实现

// 矩阵分块操作
void matrix_partition(int **M, int **M11, int **M12, int **M21, int **M22, int n) {
    int half = n / 2;
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            M11[i][j] = M[i][j];
            M12[i][j] = M[i][j + half];
            M21[i][j] = M[i + half][j];
            M22[i][j] = M[i + half][j + half];
        }
    }
}

// 矩阵合并操作
void matrix_merge(int **M, int **M11, int **M12, int **M21, int **M22, int half) {
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            M[i][j] = M11[i][j];
            M[i][j + half] = M12[i][j];
            M[i + half][j] = M21[i][j];
            M[i + half][j + half] = M22[i][j];
        }
    }
}

// Strassen核心算法
void strassen_multiply(int **A, int **B, int **C, int n) {
    // 基线条件:小矩阵使用朴素算法
    if (n <= 64) {
        matrix_multiply(A, B, C, n);
        return;
    }
    
    int half = n / 2;
    
    // 分配子矩阵内存
    int **A11 = allocate_matrix(half), **A12 = allocate_matrix(half);
    int **A21 = allocate_matrix(half), **A22 = allocate_matrix(half);
    int **B11 = allocate_matrix(half), **B12 = allocate_matrix(half);
    int **B21 = allocate_matrix(half), **B22 = allocate_matrix(half);
    
    // 分块
    matrix_partition(A, A11, A12, A21, A22, n);
    matrix_partition(B, B11, B12, B21, B22, n);
    
    // 创建S矩阵
    int **S1 = allocate_matrix(half), **S2 = allocate_matrix(half);
    // ... 共10个S矩阵
    
    // 创建P矩阵
    int **P1 = allocate_matrix(half), **P2 = allocate_matrix(half);
    // ... 共7个P矩阵
    
    // 计算S矩阵
    matrix_sub(B12, B22, S1, half); // S1 = B12 - B22
    matrix_add(A11, A12, S2, half); // S2 = A11 + A12
    // ... 其他S矩阵
    
    // 递归计算P矩阵
    strassen_multiply(A11, S1, P1, half); // P1 = A11 * S1
    strassen_multiply(S2, B22, P2, half); // P2 = S2 * B22
    // ... 其他P矩阵
    
    // 计算C的子矩阵
    int **C11 = allocate_matrix(half), **C12 = allocate_matrix(half);
    int **C21 = allocate_matrix(half), **C22 = allocate_matrix(half);
    
    // C11 = P5 + P4 - P2 + P6
    matrix_add(P5, P4, C11, half);
    matrix_sub(C11, P2, C11, half);
    matrix_add(C11, P6, C11, half);
    
    // C12 = P1 + P2
    matrix_add(P1, P2, C12, half);
    
    // C21 = P3 + P4
    matrix_add(P3, P4, C21, half);
    
    // C22 = P5 + P1 - P3 - P7
    matrix_add(P5, P1, C22, half);
    matrix_sub(C22, P3, C22, half);
    matrix_sub(C22, P7, C22, half);
    
    // 合并结果
    matrix_merge(C, C11, C12, C21, C22, half);
    
    // 释放内存
    free_matrix(A11, half); // 释放所有临时矩阵
    // ...
}

// 性能对比
void performance_test() {
    int sizes[] = {128, 256, 512, 1024};
    printf("Size\tNaive(ms)\tStrassen(ms)\tSpeedup\n");
    
    for (int i = 0; i < 4; i++) {
        int n = sizes[i];
        int **A = random_matrix(n);
        int **B = random_matrix(n);
        int **C1 = allocate_matrix(n);
        int **C2 = allocate_matrix(n);
        
        clock_t start = clock();
        matrix_multiply(A, B, C1, n);
        double naive_time = (double)(clock() - start) * 1000 / CLOCKS_PER_SEC;
        
        start = clock();
        strassen_multiply(A, B, C2, n);
        double strassen_time = (double)(clock() - start) * 1000 / CLOCKS_PER_SEC;
        
        printf("%d\t%.2f\t\t%.2f\t\t%.2fx\n", 
               n, naive_time, strassen_time, naive_time / strassen_time);
        
        free_matrix(A, n);
        free_matrix(B, n);
        free_matrix(C1, n);
        free_matrix(C2, n);
    }
}

性能对比结果

矩阵规模 朴素算法(ms) Strassen(ms) 加速比
128×128 120.5 85.2 1.41x
256×256 965.3 512.7 1.88x
512×512 7,850.6 3,120.4 2.52x
1024×1024 63,200.8 21,450.3 2.95x

5. 最近点对问题

5.1 问题定义

给定平面上的n个点,找到距离最近的两个点

5.2 分治算法步骤

  1. 按x坐标排序点集
  2. 递归求解左右两半的最近点对
  3. 考虑跨分割线的点对(带状区域)
  4. 在带状区域中按y坐标排序并检查有限个点

5.3 C语言实现

typedef struct {
    double x;
    double y;
} Point;

double distance(Point p1, Point p2) {
    double dx = p1.x - p2.x;
    double dy = p1.y - p2.y;
    return sqrt(dx*dx + dy*dy);
}

double closest_pair(Point points[], int n) {
    // 基线条件
    if (n <= 3) {
        double min_dist = DBL_MAX;
        for (int i = 0; i < n; i++) {
            for (int j = i+1; j < n; j++) {
                double dist = distance(points[i], points[j]);
                if (dist < min_dist) min_dist = dist;
            }
        }
        return min_dist;
    }
    
    // 按x坐标排序
    qsort(points, n, sizeof(Point), compare_x);
    
    // 分割点集
    int mid = n / 2;
    Point mid_point = points[mid];
    
    // 递归求解左右两半
    double dl = closest_pair(points, mid);
    double dr = closest_pair(points + mid, n - mid);
    double d = fmin(dl, dr);
    
    // 构建带状区域
    Point strip[n];
    int strip_size = 0;
    for (int i = 0; i < n; i++) {
        if (fabs(points[i].x - mid_point.x) < d) {
            strip[strip_size++] = points[i];
        }
    }
    
    // 按y坐标排序带状区域
    qsort(strip, strip_size, sizeof(Point), compare_y);
    
    // 检查带状区域内的点
    double min_strip = d;
    for (int i = 0; i < strip_size; i++) {
        // 只需检查后续7个点(数学证明)
        for (int j = i+1; j < strip_size && (strip[j].y - strip[i].y) < min_strip; j++) {
            double dist = distance(strip[i], strip[j]);
            if (dist < min_strip) min_strip = dist;
        }
    }
    
    return fmin(d, min_strip);
}

时间复杂度分析

T(n) = 2T(n/2) + O(n log n)  // 排序带状区域
     = 2T(n/2) + O(n)       // 优化:归并排序
     = O(n log n)

6. 分治策略优化技巧

6.1 避免重复计算

矩阵乘法的缓存优化

void matrix_multiply_optimized(int **A, int **B, int **C, int n) {
    // 分块优化
    const int BLOCK_SIZE = 32;
    for (int i = 0; i < n; i += BLOCK_SIZE) {
        for (int j = 0; j < n; j += BLOCK_SIZE) {
            for (int k = 0; k < n; k += BLOCK_SIZE) {
                // 处理分块
                for (int ii = i; ii < i + BLOCK_SIZE && ii < n; ii++) {
                    for (int kk = k; kk < k + BLOCK_SIZE && kk < n; kk++) {
                        for (int jj = j; jj < j + BLOCK_SIZE && jj < n; jj++) {
                            C[ii][jj] += A[ii][kk] * B[kk][jj];
                        }
                    }
                }
            }
        }
    }
}

6.2 混合策略

快速排序与插入排序混合

void hybrid_quick_sort(int arr[], int low, int high) {
    while (high - low > 0) {
        // 小数组使用插入排序
        if (high - low < 16) {
            insertion_sort(arr + low, high - low + 1);
            return;
        }
        
        // 分区操作
        int pi = partition(arr, low, high);
        
        // 优化递归:先处理较短的子数组
        if (pi - low < high - pi) {
            hybrid_quick_sort(arr, low, pi - 1);
            low = pi + 1;
        } else {
            hybrid_quick_sort(arr, pi + 1, high);
            high = pi - 1;
        }
    }
}

6.3 并行化分治算法

#include 

void parallel_merge_sort(int arr[], int low, int high) {
    if (low < high) {
        int mid = (low + high) / 2;
        
        #pragma omp parallel sections
        {
            #pragma omp section
            parallel_merge_sort(arr, low, mid);
            
            #pragma omp section
            parallel_merge_sort(arr, mid + 1, high);
        }
        
        merge(arr, low, mid, high);
    }
}

// 性能对比(8核CPU):
// n=10,000,000: 串行 2.8s, 并行 0.4s, 加速比7x

总结与思考

本章深入探讨了分治策略的核心原理与应用:

  1. 递归式求解:代入法、递归树法、主方法
  2. 经典问题实现:最大子数组、矩阵乘法、最近点对
  3. 优化技巧:避免重复计算、混合策略、并行化
  4. 复杂度分析:理解算法效率的数学基础

关键洞见:分治策略通过将大问题分解为小问题,利用递归和合并解决复杂问题。其效率取决于子问题分解的平衡性和合并操作的成本。

下章预告:第五章《概率分析与随机算法》将探讨:

  • 随机算法的设计与分析
  • 概率论在算法中的应用
  • 抽样与随机选择算法
  • 哈希表的随机化分析

本文完整代码已上传至GitHub仓库:Algorithm-Implementations

思考题

  1. 在Strassen算法中,为什么当矩阵规模较小时要切换回朴素算法?
  2. 如何证明最近点对算法中带状区域只需检查7个点?
  3. 分治策略在哪些情况下可能不是最优选择?
  4. 如何将分治策略应用于机器学习算法(如决策树训练)?

你可能感兴趣的:(算法导论,数据结构与算法,算法,数据结构,c语言,性能优化)