LeetCode Median of Two Sorted Arrays

  1 #include <iostream>

  2 #include <cstdlib>

  3 #include <cmath>

  4 #include <algorithm>

  5 

  6 using namespace std;

  7 

  8 class Solution {

  9 public:

 10     double findMedianSortedArrays(int A[], int m, int B[], int n) {

 11         double ma = 0;

 12         double mb = 0;

 13 

 14         bool empty_a = A == NULL || m < 1;

 15         bool empty_b = B == NULL || n < 1;

 16         

 17         if (!empty_a) ma = (A[(m - 1) / 2] + A[m/2]) / 2.0;

 18         if (!empty_b) mb = (B[(n - 1) / 2] + B[n/2]) / 2.0;

 19         

 20         if (empty_a && empty_b) { // will this happen ?

 21             return 0;

 22         } else if (empty_a) {

 23             return mb;

 24         } else if (empty_b) {

 25             return ma;

 26         }

 27         

 28         double low = 0, high = 0;

 29 

 30         if (ma > mb) {

 31             low = mb, high = ma;

 32         } else if (ma < mb) {

 33             low = ma, high = mb;

 34         } else {

 35             return ma;

 36         }

 37         

 38         double precise = 0.1;

 39         double mv = 0;

 40         int total = m + n;

 41         int half  = total / 2;

 42         bool declared = false;

 43         while(high - low > precise) {

 44             mv = (high + low) / 2.0;

 45             int* pa = lower_bound(A, A + m, mv);

 46             int* pb = lower_bound(B, B + n, mv);

 47             int lh = (pa - A) + (pb - B);

 48 

 49             if (lh < half) {        // the median assumed is too small, so increase it

 50                 low = mv;

 51             } else if (lh > half) { // the median assumed is too big, so decrease it

 52                 high= mv;

 53             } else {

 54                 declared = true;

 55                 // divided into odd/even case. should re-calculate the mv

 56                 // for even case median calculated from two adjacent numbers in

 57                 // the merged array, we assume that one is mmore and the other

 58                 // is mless (median = (mmore + mless) / 2.0 )

 59                 int mmore = 0;

 60                 // find bigger number to compute median for even case.

 61                 if (pa == A + m && pb == B + n) {

 62                     // should not happen;

 63                     cout<<"[1]should not happen"<<endl;

 64                 } else if (pa == A + m) {

 65                     mmore = *pb;

 66                 } else if (pb == B + n) {

 67                     mmore = *pa;

 68                 } else {

 69                     if (*pa < *pb) {

 70                         mmore = *pa;

 71                     } else {

 72                         mmore = *pb;

 73                     }

 74                 }

 75                 

 76                 // for odd case. the mv is equal to value of mmore

 77                 if (half * 2 != total) {

 78                     mv = mmore;

 79                     break;

 80                 }

 81                 

 82                 // find samller number to compute median for even case.

 83                 pa--, pb--;

 84                 int mless = 0;

 85                 if (pa < A && pb < B) {

 86                     // should not happen

 87                     cout<<"[2]should not happen"<<endl;

 88                 } else if (pa < A) {

 89                     mless = *pb;

 90                 } else if (pb < B) {

 91                     mless = *pa;

 92                 } else {

 93                     if (*pb > * pa) {

 94                         mless = *pb;

 95                     } else {

 96                         mless = *pa;

 97                     }

 98                 }

 99                 mv = (mless + mmore) / 2.0;

100                 break;

101             }

102         }

103         if (declared) { // median value is on the boundary

104             return mv;

105         }

106         if (fabs(mv - ma) < fabs(mv - mb)) {

107             return ma;

108         } else {

109             return mb;

110         }

111     }

112 };

113 

114 int main() {

115     Solution s;

116     int A[] = {1, 1};

117     int B[] = {1, 2};

118     int m = sizeof(A) / sizeof(A[0]);

119     int n = sizeof(B) / sizeof(B[0]);

120     

121     cout<<s.findMedianSortedArrays(A, m, B, n)<<endl;

122     system("pause");

123     return 0;

124 }

写得好乱啊。这其实还是二分搜索,只不过用来决定选择前半部还是后半部的判断标准变了:由原来的与一个确定常数比较,变为两个变量之间的比较(lh 与 half 之间的大小关系);搜索空间也由一个数组变为一个数值区间(其实两者都可以看做解的值域)。运行时间 230ms+。

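题目中要求 "The overall run time complexity should be O(log (m+n))."。按数值二分的做法,其实在 log 前面还有常数,而且 log 的对象是数值区间的宽度而不是 m+n。由于数据是 32 位整数,初始区间宽度小于 2^32,经过 32 次二分可以使数值区间缩小到 1 以内,再过 4 次即可降到 0.1 以内(2^-4 = 0.0625 < 0.1)。另外,每一次二分还要在两个数组上各做一次 lower_bound,代价为 O(log m + log n)。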

 

再用 O(m+n) 的简单合并解法试了一下,运行时间感觉差不多,不知为何(可能是测试数据规模较小,运行时间主要由常数开销决定)。

 1 class Solution {

 2 public:

 3     double findMedianSortedArrays(int A[], int m, int B[], int n) {

 4         int ia = 0, ib = 0;

 5         int it = -1;

 6         int im = (m + n - 1) / 2;

 7         int val= 0;

 8 

 9         bool empty_a = A == NULL || m < 1;

10         bool empty_b = B == NULL || n < 1;

11         

12         while (!empty_a && ia < m && !empty_b && ib < n && it < im) {

13             if (A[ia] < B[ib]) {

14                 val = A[ia++];

15             } else {

16                 val = B[ib++];

17             }

18             ++it;

19         }

20         

21         while (!empty_a && ia < m && it < im) {

22             val = A[ia++];

23             it++;

24         }

25         while (!empty_b && ib < n && it < im) {

26             val = B[ib++];

27             it++;

28         }

29         if ((m + n) & 1) {

30             return val;

31         } else {

32             int val2 = 0;

33             if ((empty_a || ia >= m) && (empty_b || ib >= n)) {

34                 // should not happen

35             } else if (empty_a || ia >= m) {

36                 val2 = B[ib];

37             } else if (empty_b || ib >= n) {

38                 val2 = A[ia];

39             } else {

40                 val2 = A[ia] > B[ib] ? B[ib] : A[ia];

41             }

42             return (val + val2) / 2.0;

43         }

44     }

45 };

在 Discuss 区找到一份 O(log(m+n)) 的代码:

/// Median of two sorted arrays via k-th smallest selection, O(log(m+n)).
class Solution {
public:
    /// Returns the median of the union of ascending arrays A[0..m) and B[0..n).
    /// Returns 0 when both arrays are empty, since no median exists.
    double findMedianSortedArrays(int A[], int m, int B[], int n) {
        int length = m + n;
        if (length <= 0) {
            return 0;  // findkth(..., 0) would read B[-1]
        }
        if (length % 2) {
            return findkth(A, m, B, n, length / 2 + 1);
        }
        // Convert before adding so two large ints cannot overflow.
        return (double(findkth(A, m, B, n, length / 2)) + findkth(A, m, B, n, length / 2 + 1)) / 2;
    }

    /// Returns the k-th smallest (1-based) element of A[0..m) ∪ B[0..n).
    /// Requires 1 <= k <= m + n.
    int findkth(int A[], int m, int B[], int n, int k) {
        if (m > n) {
            return findkth(B, n, A, m, k);  // keep A as the shorter array
        }
        if (m == 0) {
            return B[k - 1];
        }
        if (k == 1) {
            return A[0] < B[0] ? A[0] : B[0];
        }
        // Take pa elements from A and pb from B, so that pa + pb == k.
        int pa = k / 2 < m ? k / 2 : m;
        int pb = k - pa;
        if (A[pa - 1] == B[pb - 1]) {
            return A[pa - 1];
        }
        if (A[pa - 1] < B[pb - 1]) {
            // A[0..pa) cannot hold the k-th element; B beyond pb is not needed either.
            return findkth(A + pa, m - pa, B, pb, k - pa);
        }
        return findkth(A, pa, B + pb, n - pb, k - pb);
    }
};

花点时间理解:

LeetCode Median of Two Sorted Arrays

下面先不考虑 K/2 ≥ Na(数组 A 的长度不足 K/2)以及 K=1(此时比较两数组首元素即可得出)的情况,数组下标从 0 开始。求第 K 小的数时,首先取 pa = K/2,pb = K − K/2;这样 {A[0], A[1], ..., A[pa-1]} 的元素个数加上 {B[0], B[1], ..., B[pb-1]} 的元素个数刚好等于 K。此时如果:

1. A[pa-1] = B[pb-1],那么很容易知道 A[pa-1](也就是 B[pb-1])就是第 K 小的数。因为数组是已排序的,{A[0], ..., A[pa-1]} 与 {B[0], ..., B[pb-1]} 中的元素都不大于它,且 |{A[0], ..., A[pa-1]}| + |{B[0], ..., B[pb-1]}| = K;而两个数组中其余的元素都不小于它。

2. A[pa-1] < B[pb-1],那么可以断定第 K 小的数肯定不在数组 A 的 [0, pa-1] 这个区间内。用反证法可以证明:

假设第 K 小的数存在于 A[0...pa-1] 中,设其为 X,则根据第 K 小的数的含义,在排序后的序列中它前面必然有 K-1 个数,且都小于等于 X。由于 X 位于 A[0...pa-1] 中,数组 A 中排在 X 之前的数最多只有 |{A[0], A[1], ..., A[pa-1]}| - 1 = K/2 - 1 个(X = A[pa-1] 时取到),小于 K-1。剩下的数需要从 B 数组中取,至少需要 (K - 1) - (K/2 - 1) = K - K/2 = pb 个数,也就是说 B[0], ..., B[pb-1] 都必须小于等于 X。但由于存在条件 A[pa-1] < B[pb-1],有 B[pb-1] > A[pa-1] ≥ X,产生矛盾,故假设不成立。所以第 K 小的数肯定不在数组 A 的 [0, pa-1] 这个区间内,此时我们只需要在剩下的区间内搜索就可以了:寻找第 K 小的元素,变为寻找第 (K-pa) 小的元素(因为我们已经排除了数组 A 中前 pa 个元素)。

3. A[pa-1] > B[pb-1],这种情况是第二种情况的对称情况:可以排除 {B[0], B[1], ..., B[pb-1]} 这个搜索区间,并继续寻找第 (K-pb) 小的元素。

 

再来一次:

/// Median of two sorted vectors via k-th smallest selection, O(log(m+n)).
class Solution {
public:
    /// Returns the median of the union of two ascending vectors.
    /// Returns 0 when both vectors are empty, since no median exists.
    double findMedianSortedArrays(std::vector<int>& nums1, std::vector<int>& nums2) {
        const int len1 = static_cast<int>(nums1.size());
        const int len2 = static_cast<int>(nums2.size());
        const int total = len1 + len2;
        if (total == 0) {
            return 0;  // findK(..., 0) would read b[-1]
        }
        // data() is well-defined for an empty vector; &v[0] on one is undefined behaviour.
        const int* a = nums1.data();
        const int* b = nums2.data();
        if (total & 0x1) {
            return findK(a, b, len1, len2, total / 2 + 1);
        }
        const double lo = findK(a, b, len1, len2, total / 2);
        const double hi = findK(a, b, len1, len2, total / 2 + 1);
        return (lo + hi) / 2;
    }

    /// Returns the k-th smallest (1-based) element of a[0..na) ∪ b[0..nb).
    /// Requires 1 <= k <= na + nb.
    int findK(const int* a, const int* b, int na, int nb, int k) {
        if (nb < na) {
            return findK(b, a, nb, na, k);  // keep a as the shorter array
        }
        if (na == 0) {
            return b[k - 1];
        }
        if (k == 1) {
            return a[0] > b[0] ? b[0] : a[0];
        }

        // Take pa elements from a and pb from b, so that pa + pb == k.
        int pa = k / 2 < na ? k / 2 : na;
        int pb = k - pa;
        if (a[pa - 1] == b[pb - 1]) {
            return a[pa - 1];
        } else if (a[pa - 1] < b[pb - 1]) {
            return findK(a + pa, b, na - pa, nb, k - pa);  // drop a[0..pa)
        } else {
            return findK(a, b + pb, na, nb - pb, k - pb);  // drop b[0..pb)
        }
    }
};

 

你可能感兴趣的:(LeetCode)