Skip to content

4. 寻找两个正序数组的中位数

题目描述

给定两个大小分别为 mn 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数

算法的时间复杂度应该为 O(log (m+n))

 

示例 1:

输入:nums1 = [1,3], nums2 = [2]
输出:2.00000
解释:合并数组 = [1,2,3] ,中位数 2

示例 2:

输入:nums1 = [1,2], nums2 = [3,4]
输出:2.50000
解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5

 

 

提示:

  • nums1.length == m
  • nums2.length == n
  • 0 <= m <= 1000
  • 0 <= n <= 1000
  • 1 <= m + n <= 2000
  • -106 <= nums1[i], nums2[i] <= 106

方法一:分治

题目要求算法的时间复杂度为 O(log(m+n)),因此不能直接遍历两个数组,而是需要使用二分查找的方法。

如果 m+n 是奇数,那么中位数就是第 m+n+12 个数;如果 m+n 是偶数,那么中位数就是第 m+n+12 和第 m+n+22 个数的平均数。实际上,我们可以统一为求第 m+n+12 个数和第 m+n+22 个数的平均数。

因此,我们可以设计一个函数 f(i,j,k),表示在数组 nums1 的区间 [i,m) 和数组 nums2 的区间 [j,n) 中,求第 k 小的数。那么中位数就是 f(0,0,m+n+12)f(0,0,m+n+22) 的平均数。

函数 f(i,j,k) 的实现思路如下:

  • 如果 im,说明数组 nums1 的区间 [i,m) 为空,因此直接返回 nums2[j+k1]
  • 如果 jn,说明数组 nums2 的区间 [j,n) 为空,因此直接返回 nums1[i+k1]
  • 如果 k=1,说明要找第一个数,因此只需要返回 nums1[i]nums2[j] 中的最小值;
  • 否则,我们分别在两个数组中查找第 k2 个数,设为 xy。(注意,如果某个数组不存在第 k2 个数,那么我们将第 k2 个数视为 +。)比较 xy 的大小:
    • 如果 xy,则说明数组 nums1 的第 k2 个数不可能是第 k 小的数,因此我们可以排除数组 nums1 的区间 [i,i+k2),递归调用 f(i+k2,j,kk2)
    • 如果 x>y,则说明数组 nums2 的第 k2 个数不可能是第 k 小的数,因此我们可以排除数组 nums2 的区间 [j,j+k2),递归调用 f(i,j+k2,kk2)

时间复杂度 O(log(m+n)),空间复杂度 O(log(m+n))。其中 mn 分别是数组 nums1nums2 的长度。

java
class Solution {
    private int m;
    private int n;
    private int[] nums1;
    private int[] nums2;

    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        m = nums1.length;
        n = nums2.length;
        this.nums1 = nums1;
        this.nums2 = nums2;
        int a = f(0, 0, (m + n + 1) / 2);
        int b = f(0, 0, (m + n + 2) / 2);
        return (a + b) / 2.0;
    }

    private int f(int i, int j, int k) {
        if (i >= m) {
            return nums2[j + k - 1];
        }
        if (j >= n) {
            return nums1[i + k - 1];
        }
        if (k == 1) {
            return Math.min(nums1[i], nums2[j]);
        }
        int p = k / 2;
        int x = i + p - 1 < m ? nums1[i + p - 1] : 1 << 30;
        int y = j + p - 1 < n ? nums2[j + p - 1] : 1 << 30;
        return x < y ? f(i + p, j, k - p) : f(i, j + p, k - p);
    }
}
cpp
class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        int m = nums1.size(), n = nums2.size();
        function<int(int, int, int)> f = [&](int i, int j, int k) {
            if (i >= m) {
                return nums2[j + k - 1];
            }
            if (j >= n) {
                return nums1[i + k - 1];
            }
            if (k == 1) {
                return min(nums1[i], nums2[j]);
            }
            int p = k / 2;
            int x = i + p - 1 < m ? nums1[i + p - 1] : 1 << 30;
            int y = j + p - 1 < n ? nums2[j + p - 1] : 1 << 30;
            return x < y ? f(i + p, j, k - p) : f(i, j + p, k - p);
        };
        int a = f(0, 0, (m + n + 1) / 2);
        int b = f(0, 0, (m + n + 2) / 2);
        return (a + b) / 2.0;
    }
};
ts
function findMedianSortedArrays(nums1: number[], nums2: number[]): number {
    const m = nums1.length;
    const n = nums2.length;
    const f = (i: number, j: number, k: number): number => {
        if (i >= m) {
            return nums2[j + k - 1];
        }
        if (j >= n) {
            return nums1[i + k - 1];
        }
        if (k == 1) {
            return Math.min(nums1[i], nums2[j]);
        }
        const p = Math.floor(k / 2);
        const x = i + p - 1 < m ? nums1[i + p - 1] : 1 << 30;
        const y = j + p - 1 < n ? nums2[j + p - 1] : 1 << 30;
        return x < y ? f(i + p, j, k - p) : f(i, j + p, k - p);
    };
    const a = f(0, 0, Math.floor((m + n + 1) / 2));
    const b = f(0, 0, Math.floor((m + n + 2) / 2));
    return (a + b) / 2;
}
python
class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        def f(i: int, j: int, k: int) -> int:
            if i >= m:
                return nums2[j + k - 1]
            if j >= n:
                return nums1[i + k - 1]
            if k == 1:
                return min(nums1[i], nums2[j])
            p = k // 2
            x = nums1[i + p - 1] if i + p - 1 < m else inf
            y = nums2[j + p - 1] if j + p - 1 < n else inf
            return f(i + p, j, k - p) if x < y else f(i, j + p, k - p)

        m, n = len(nums1), len(nums2)
        a = f(0, 0, (m + n + 1) // 2)
        b = f(0, 0, (m + n + 2) // 2)
        return (a + b) / 2

Released under the MIT License.