给定两个大小为 m 和 n 的有序数组 nums1 和 nums2。
请你找出这两个有序数组的中位数,并且要求算法的时间复杂度为 O(log(m + n))。
你可以假设 nums1 和 nums2 不会同时为空。
示例 1:
nums1 = [1, 3]
nums2 = [2]
则中位数是 2.0
示例 2:
nums1 = [1, 2]
nums2 = [3, 4]
则中位数是 (2 + 3)/2 = 2.5
来源:力扣(LeetCode)
链接:https://leetcode-cn.com/problems/median-of-two-sorted-arrays
著作权归领扣网络所有。商业转载请联系官方授权,非商业转载请注明出处。
根据题目提示信息,算法的时间复杂度可达到 O(log(m + n))。log复杂度提示,可考虑采用二分法进行求解。
在两数组中,对应中位数位置的数,需满足:
1)左边数字个数等于右边;
2)左边最大数字小于右边最小数字(由于单个数组已实现有序,只需交叉考虑和另一数组是否符合)。
优化思路:
1)固定题目所给数组1、2的大小关系,便于求解边界情况;
2)对简单特殊情况先行判断,减小后续临界情况下求解工作量。
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int r1 = nums1.size(), i1 = r1 / 2;
int r2 = nums2.size(), i2 = r2 / 2;
int num = (r1 + r2) / 2 - 1; //下标从0开始,预先-1从而直接采用下标运算
bool odd;
if ((r1 + r2) % 2 != 0) {
odd = true;
}
else {
odd = false;
}
if (!r1) {
if (odd)
return nums2[i2];
else
return (double)(nums2[i2 - 1] + nums2[i2]) / 2;
}
else if (!r2) {
if (odd)
return nums1[i1];
else
return (double)(nums1[i1 - 1] + nums1[i1]) / 2;
}
if (r1 == 1 && r2 == 1) {
return (double)(nums1[0] + nums2[0]) / 2;
}
else if (r1 == 1) {
if (odd)
return (nums1[0] > nums2[i2]) ? nums2[i2] : max(nums1[0], nums2[i2 - 1]);
else
return (nums1[0] > nums2[i2]) ? (double)(nums2[i2] + min(nums2[i2 + 1], nums1[0])) / 2 : (double)(nums2[i2] + max(nums2[i2 - 1], nums1[0])) / 2;
}
else if (r2 == 1) {
if (odd)
return (nums2[0] > nums1[i1]) ? nums1[i1] : max(nums2[0], nums1[i1 - 1]);
else
return (nums2[0] > nums1[i1]) ? (double)(nums1[i1] + min(nums1[i1 + 1], nums2[0])) / 2 : (double)(nums1[i1] + max(nums1[i1 - 1], nums2[0])) / 2;
}
while (((i1 + i2) != num) || ((i2 + 1) < nums2.size() && (nums1[i1] > nums2[i2 + 1]) && i1 && i2) || ((i1 + 1) < nums1.size() && nums2[i2] > nums1[i1 + 1] && i1 && i2)) {
if ((i1 + i2) >= num) {
if (!i1 || !i2) {
if (!i1) {
r2 = i2;
i2 = i2 / 2;
}
else {
r1 = i1;
i1 = i1 / 2;
}
}
else if (nums1[i1] >= nums2[i2]) {
r1 = i1;
i1 = i1 / 2;
}
else {
r2 = i2;
i2 = i2 / 2;
}
}
else {
if (i1 == nums1.size() - 1) {
i2 = (i2 + r2) / 2;
}
else if (i2 == nums2.size() - 1) {
i1 = (i1 + r1) / 2;
}
else if (nums1[i1 + 1] >= nums2[i2 + 1]) {
i2 = (i2 + r2) / 2;
}
else {
i1 = (i1 + r1) / 2;
}
}
}
if (odd) {
if (i1 == 0) {
if (i2 == (nums2.size() - 1))
return max(nums1[0], nums2[i2]);
else if (nums1[0] > nums2[i2 + 1])
return nums2[i2 + 1];
else
return max(nums1[0], nums2[i2]);
}
else if (i2 == 0) {
if (i1 == (nums1.size() - 1))
return max(nums1[i1], nums2[0]);
else if (nums2[0] > nums1[i1 + 1])
return nums1[i1 + 1];
else
return max(nums1[i1], nums2[0]);
}
else {
return max(nums1[i1], nums2[i2]);
}
}
else {
if (i1 == 0) {
if (i2 == (nums2.size() - 1)) {
if (i2)
return (double)(nums2[i2] + max(nums2[i2 - 1], nums1[0])) / 2;
else
return (double)(nums2[i2] + nums1[0]) / 2;
}
else {
if (nums1[0] > nums2[i2])
return (double)(min(nums2[i2 + 1], nums1[0]) + nums2[i2]) / 2;
else
return (double)(max(nums2[i2 - 1], nums1[0]) + nums2[i2]) / 2;
}
}
else if (i2 == 0) {
if (i1 == (nums1.size() - 1)) {
if (i1)
return (double)(nums1[i1] + max(nums1[i1 - 1], nums2[0])) / 2;
else
return (double)(nums1[i1] + nums2[0]) / 2;
}
else {
if (nums2[0] > nums1[i1])
return (double)(min(nums1[i1 + 1], nums2[0]) + nums1[i1]) / 2;
else
return (double)(max(nums1[i1 - 1], nums2[0]) + nums1[i1]) / 2;
}
}
else {
if (i1 == nums1.size() - 1) {
if (nums2[i2] > nums1[i1])
return (double)(max(nums1[i1], nums2[i2 - 1]) + nums2[i2]) / 2;
else
return (double)(max(nums1[i1 - 1], nums2[i2]) + nums1[i1]) / 2;
}
else if (i2 == nums2.size() - 1) {
if (nums1[i1] > nums2[i2])
return (double)(max(nums2[i2], nums1[i1 - 1]) + nums1[i1]) / 2;
else
return (double)(max(nums2[i2 - 1], nums1[i1]) + nums2[i2]) / 2;
}
else {
if (nums1[i1] > nums2[i2])
return (double)(max(nums1[i1 - 1], nums2[i2]) + nums1[i1]) / 2;
else
return (double)(max(nums2[i2 - 1], nums1[i1]) + nums2[i2]) / 2;
}
}
}
return 0;
}
};
执行用时 : 20 ms, 在所有 C++ 提交中击败了 92.89% 的用户;
内存消耗 : 9.6 MB。
能够完成题目,但实现复杂、代码繁琐。
上述方法虽能够实现,但过程中代码繁琐复杂,极易出错。
优化思路:
固定左边数组总数,减小判断的复杂程度。
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
if (nums2.size() > nums1.size()) { //使nums1长度大于nums2
vector<int> numstmp = nums1;
nums1 = nums2;
nums2 = numstmp;
}
bool odd = (nums1.size() + nums2.size()) % 2;
int length = (nums1.size() + nums2.size() + 1) / 2;
if (nums1.size() == 1)
return odd ? nums1[0] : double (nums1[0] + nums2[0]) / 2;
if (nums2.size() == 1) {
if (odd)
return min(max(nums1[length - 2], nums2[0]), nums1[length - 1]);
else {
if (nums2[0] < nums1[length - 1])
return (double)(nums1[length - 1] + max(nums1[length - 2], nums2[0])) / 2;
else
return (double)(nums1[length - 1] + min(nums1[length], nums2[0])) / 2;
}
}
if (nums2.size() == 0 || (nums1.size() > nums2.size() && nums1[nums1.size() - 1] <= nums2[0]))
return odd ? nums1[length - 1] : (double)(nums1[length - 1] + nums1[length]) / 2;
if (nums1.size() + nums2.size() == 2)
return (double)(nums1[0] + nums1[1]) / 2;
length = (nums1.size() + nums2.size()) / 2 - 1;;//用于两数组坐标
int l1 = nums1.size() / 2;
int l2 = length - l1;//固定左边数字总数,减小变量
int r1 = nums1.size();
while ((l2 < (nums2.size() - 1) && nums1[l1] > nums2[l2 + 1]) || (l2 && nums2[l2] > nums1[l1 + 1])) {
if (l2 < (nums2.size() - 1) && nums1[l1] > nums2[l2 + 1]) {
r1 = l1;
l1 = r1 / 2;
l2 = length - l1;
if (l2 > nums2.size() - 1) {l2 = nums2.size() - 1; l1 = length - l2;}
}
else {
l1 = (l1 + r1 + 1) / 2; //+1避免l1 = r1 - 1时进入死循环
l2 = length - l1;
if (l2 < 0) { l2 = 0; l1 = length - l2; }
}
}
if (l2 == nums2.size() - 1 && nums1.size() != nums2.size())
return odd ? max(nums1[l1], nums2[l2]) : (double)(max(nums1[l1 - 1], nums2[l2]) + max(nums1[l1], nums2[l2 - 1])) / 2;
else if (l2 == 0 && l1 < nums1.size() - 2)
return odd ? max(min(nums1[l1 + 1], nums2[0]), nums1[l1]) : (double)(nums1[l1] + max(min(nums1[l1 + 1], nums2[0]), nums1[l1 - 1])) / 2;
else if (l2 == 0)
return odd ? max(nums1[l1], nums2[l2]) : (double)(nums1[l1] + max(nums2[0], nums1[l1 - 1])) / 2;
else if (l1 == 0)
return odd ? max(nums1[l1], nums2[l2]) : (double)(nums2[l2] + max(nums1[0], nums2[l2 - 1])) / 2;
else
return odd ? max(nums1[l1], nums2[l2]) : (double)(max(nums1[l1 - 1], nums2[l2]) + max(nums1[l1], nums2[l2 - 1])) / 2;
return 0;
}
};
执行用时 : 20 ms, 在所有 C++ 提交中击败了 92.89% 的用户;
内存消耗 : 9.7 MB。
代码繁琐程度略有优化,但各类临界值处理复杂,仍极易出错。
迭代优化思路:
设置新的变量 l 和 r 作为迭代判断用的下标,根据数组对应值判断调整下标,避免复杂的判断过程。
临界值处理优化思路:
引入最大、最小整数作为临界对应值,统一输出结果。
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n = nums1.size(), m = nums2.size();
if (n > m)
return findMedianSortedArrays(nums2, nums1);
int length = (n + m + 1) / 2;
int l = 0, r = n;
int i = (l + r) / 2;
int j = length - i;
while (l < r) {
if (nums1[i] >= nums2[j - 1])
r = i;
else
l = i + 1;
i = (l + r) / 2;
j = length - i;
}
int lnum = max(i > 0 ? nums1[i - 1] : INT_MIN, j > 0 ? nums2[j - 1] : INT_MIN);
if ((n + m) % 2)
return lnum;
int rnum = min(i < n ? nums1[i] : INT_MAX, j < m ? nums2[j] : INT_MAX);
return (double)(lnum + rnum) / 2;
}
};
执行用时 : 20 ms, 在所有 C++ 提交中击败了 92.89% 的用户;
内存消耗 : 9.6 MB。
思路清晰,代码简洁。
如果输入数组的长度不符合我们的期望,可采用两种调整方式。
方法1(直接调整):
if (nums2.size() > nums1.size()) {
vector<int> numstmp = nums1;
nums1 = nums2;
nums2 = numstmp;
}
方法2(引用自身函数):
if (nums2.size() > nums1.size()) {
return findMedianSortedArrays(nums2, nums1);
}
在处理临界值时很好用。
INT_MIN= -2^31
INT_MAX = 2^31-1
本题求解耗时较长,主要原因是前两种解法的判断思路和临界值处理过于复杂,极易出错,花费大量时间debug后才能正确实现。总结经验教训如下:
(1)确定大致方向后,先理清、优化思路再coding,事半功倍;
(2)对于临界值多思考系统化的处理方法,尽量避免一个个单独求解。