MedianOfTwoSortedArrays [source code]

public class MedianOfTwoSortedArrays {
static
/******************************************************************************/
class Solution {
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        int m = nums1.length, n = nums2.length;
        if (m > n)
            return findMedianSortedArrays (nums2, nums1);
        int lo = 0, hi = m, half = (m + n + 1) / 2;
        while (lo <= hi) {
            int i = lo + (hi - lo) / 2, j = half - i,
            left1 = i > 0 ? nums1[i - 1] : Integer.MIN_VALUE,
            right1 = i < m ? nums1[i] : Integer.MAX_VALUE,
            left2 = j > 0 ? nums2[j - 1] : Integer.MIN_VALUE,
            right2 = j < n ? nums2[j] : Integer.MAX_VALUE;
            if (left1 > right2)
                hi = i - 1;
            else if (left2 > right1)
                lo = i + 1;
            else {
                int max_left = Math.max (left1, left2);
                if ((m + n) % 2 == 1)
                    return max_left;
                int min_right = Math.min (right1, right2);
                return 0.5 * ((long) max_left + min_right); // widen to long so the sum cannot overflow int
            }
        }
        return -1;
    }
}
/******************************************************************************/

    public static void main(String[] args) {
        MedianOfTwoSortedArrays.Solution tester = new MedianOfTwoSortedArrays.Solution();
        int[][] inputs = {
            {1, 3}, {2}, 
            {1, 2}, {3, 4},
            {1,1,1},{1,1,1},
            {1}, {1},
        };
        double[] answers = {
            2.0,
            2.5,
            1.0,
            1.0,
        };
        for (int i = 0; i < inputs.length / 2; i++) {
            int[] nums1 = inputs[2 * i], nums2 = inputs[2 * i + 1];
            double ans = answers[i];
            System.out.println (Printer.separator ());
            double output =tester.findMedianSortedArrays (nums1, nums2);
            System.out.printf ("[%s] and [%s] -> %s, expected: %f\n", 
                Printer.array (nums1), Printer.array (nums2), Printer.wrapColor (output + "", output == ans ? "green" : "red"), ans
            );
        }
    }
}

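I did this problem back in my algorithms class, but I have no memory of it at all; it counts as a classic problem.

I had no idea how to approach it, so I went straight to the answer.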


editorial

Approach #1 Recursive Approach [Accepted]

To solve this problem, we need to understand "What is the use of median". In statistics, the median is used for:

Dividing a set into two equal length subsets, that one subset is always greater than the other.

In the final computation, the author states that the conditions to find are:

  • i + j = m - i + n - j ( + 1);
  • max (left_part) <= min (right_part)

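Look at the first condition in particular: it says the left half is either equal in length to the right half or exactly 1 longer (not 1 shorter, as I first read it).

Also, once the first condition is rewritten as i + j = (m + n + 1) / 2 (integer division), m + n = 10 gives 5 and m + n = 9 also gives 5.

I later checked the definition from the 463 course, and it matches what the instructor gave: for any array, the median is defined at position floor ((n + 1) / 2).

The solution given there: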

class Solution {  
    public double findMedianSortedArrays(int[] A, int[] B) {  
        int m = A.length;  
        int n = B.length;  
        if (m > n) { // to ensure m<=n  
            int[] temp = A; A = B; B = temp;  
            int tmp = m; m = n; n = tmp;  
        }  
        int iMin = 0, iMax = m, halfLen = (m + n + 1) / 2;  
        while (iMin <= iMax) {  
            int i = (iMin + iMax) / 2;  
            int j = halfLen - i;  
            if (i < iMax && B[j-1] > A[i]){  
                iMin = i + 1; // i is too small  
            }  
            else if (i > iMin && A[i-1] > B[j]) {  
                iMax = i - 1; // i is too big  
            }  
            else { // i is perfect  
                int maxLeft = 0;  
                if (i == 0) { maxLeft = B[j-1]; }  
                else if (j == 0) { maxLeft = A[i-1]; }  
                else { maxLeft = Math.max(A[i-1], B[j-1]); }  
                if ( (m + n) % 2 == 1 ) { return maxLeft; }  

                int minRight = 0;  
                if (i == m) { minRight = B[j]; }  
                else if (j == n) { minRight = A[i]; }  
                else { minRight = Math.min(B[j], A[i]); }  

                return (maxLeft + minRight) / 2.0;  
            }  
        }  
        return 0.0;  
    }  
}

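The editorial article is worth reading carefully. This is a fairly complex algorithm: the code looks straightforward, but the reasoning behind it is not easy to follow.

That said, the author is a bit sloppy: the answer still contains a typo in its current version.

Still, go back and reread the original article. One hard part of this algorithm is the corner-case handling, which is easy to overlook when reading only the code above.

At its core this is still a binary search with the standard three-way branch, so writing it with a closed interval is the simplest choice. The key is understanding the math behind it, for example what the index of the median is and how to compute it.

The complexity of this algorithm is O(lg (min (M, N))).

The code at the top was written with reference to various answers; honestly it is mostly memorized.

Thinking back after writing the code, I realized there was a lot I had not understood when I was reading other people's code.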

compare and generalize

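One question: why, when m + n is odd, is the overall median simply max (A_left, B_left)? At first I could not see how to argue this. Then I realized that the two-array median can be reasoned about by analogy with the single-array median.

For a single array with odd N, you also split the N elements into pieces of length (N + 1) / 2 and (N - 1) / 2. How do you then find the median? It is the right endpoint of the left piece, which in a sorted array is the maximum of the left piece.

The same holds here. Once the binary search has split the M + N elements into two proper parts, the maximum of the left part is the median you want, and that maximum is the larger of the maximum of A's left half and the maximum of B's left half.

The even case is comparatively easier to understand.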
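The single-array analogy can be made concrete with a small sketch. It assumes a non-empty sorted int array; the names SingleArrayMedian and median are made up for illustration and do not come from the editorial.

```java
class SingleArrayMedian {
    // Hypothetical helper: median of a non-empty sorted array,
    // computed purely from the length of the left part.
    static double median(int[] sorted) {
        int n = sorted.length;
        int leftLen = (n + 1) / 2;          // left part is as long as the right part, or 1 longer
        int maxLeft = sorted[leftLen - 1];  // right endpoint of the left part
        if (n % 2 == 1) {
            return maxLeft;                 // odd: the left part holds the middle element
        }
        int minRight = sorted[leftLen];     // first element of the right part
        return ((long) maxLeft + minRight) / 2.0;
    }
}
```

The two-array algorithm does the same thing, except that the left part is spread over A[0..i-1] and B[0..j-1], so its maximum has to be taken over two candidates.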

definition and invariant

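Once the previous question is understood, it is also clear why we keep treating i + j = (M + N + 1) / 2 as the golden rule. In fact, the problem could just as well be defined with i + j = (M + N - 1) / 2; the final computation of the overall median would then have to change accordingly. For example, in the odd case we previously took the maximum of the left part; now we would take the minimum of the right part (because in the odd case the right half is now the longer one).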
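Here is a minimal sketch of that alternative definition, under one assumption of mine: the left length is taken as (M + N) / 2 with integer division. That equals (M + N - 1) / 2 when M + N is odd, and it keeps the two parts balanced when M + N is even. The class name RightHeavyMedian is hypothetical, and the sentinel style follows the solution at the top of these notes.

```java
class RightHeavyMedian {
    // Sketch only: left part has (m + n) / 2 elements, so for odd totals
    // the right part is the longer one and the median is min(right).
    static double median(int[] a, int[] b) {
        if (a.length > b.length) return median(b, a);   // ensure a is the shorter array
        int m = a.length, n = b.length;
        if (m + n == 0) throw new IllegalArgumentException("both arrays are empty");
        int half = (m + n) / 2;                          // assumed variant of the length rule
        int lo = 0, hi = m;
        while (lo <= hi) {
            int i = lo + (hi - lo) / 2, j = half - i;
            int left1 = i > 0 ? a[i - 1] : Integer.MIN_VALUE;
            int right1 = i < m ? a[i] : Integer.MAX_VALUE;
            int left2 = j > 0 ? b[j - 1] : Integer.MIN_VALUE;
            int right2 = j < n ? b[j] : Integer.MAX_VALUE;
            if (left1 > right2) {
                hi = i - 1;                              // i is too big
            } else if (left2 > right1) {
                lo = i + 1;                              // i is too small
            } else {
                int minRight = Math.min(right1, right2);
                if ((m + n) % 2 == 1) return minRight;   // odd: the extra element sits on the right
                int maxLeft = Math.max(left1, left2);
                return ((long) maxLeft + minRight) / 2.0;
            }
        }
        throw new IllegalStateException("inputs must be sorted");
    }
}
```

The binary search itself is unchanged; only the length rule and the odd-case return value move together, which is the consistency these notes are arguing for.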

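The point is that this length definition (the relation between i and j) is not fixed. What matters is knowing exactly what termination state you end up computing: first think about how to reach that state (binary search), then about how to use it (to find the median, which means the state must be chosen so that it is meaningful).

This state is really the invariant of the binary search. The two properties we want at the end are: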

len (left) = (M + N + 1) / 2  
max (left) <= min (right)

With these two properties we can be sure of finding the final median.
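The two properties can be written down as a predicate on a candidate cut i. This is only a sketch for checking a cut by hand; the name isValidCut is hypothetical, and the boundary handling mirrors the "j == 0 or i == m" style conditions from the discussion post rather than any official code.

```java
class CutInvariant {
    // Returns true when cutting a after i elements (and b after j elements,
    // with i + j = (m + n + 1) / 2) satisfies max(left) <= min(right).
    static boolean isValidCut(int[] a, int[] b, int i) {
        int m = a.length, n = b.length;
        int j = (m + n + 1) / 2 - i;                     // property 1: len(left) = (m + n + 1) / 2
        if (i < 0 || i > m || j < 0 || j > n) return false;
        // property 2, skipping comparisons whose operand does not exist
        boolean aLeftOk = i == 0 || j == n || a[i - 1] <= b[j];
        boolean bLeftOk = j == 0 || i == m || b[j - 1] <= a[i];
        return aLeftOk && bLeftOk;
    }
}
```

For a = [1, 3] and b = [2], only i = 1 passes: the left part is {1, 2} and the right part is {3}.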

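So however you define the property, it only needs to be consistent throughout (in particular, consistent with the actions in your binary search: each branch must choose an action that maintains the invariant) and useful at the end. Once you reach such a property, there is always a way to get the result you want; it is not a rigid recipe.

Of course, for well-known problems it is reasonable to memorize a few things, such as len (left) = (N + 1) / 2 for the median problem here.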


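I skimmed the solution from the 463 class. It is much more concise than the editorial's algorithm, but that problem is not quite the same as this one: it requires m == n, and the proof actually relies on that condition.

A later web search showed that the approach can be extended to unequal lengths, but it is much more tedious to write and the resulting code is nowhere near as concise as the editorial's: https://www.geeksforgeeks.org/median-of-two-sorted-arrays-of-different-sizes/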

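The complexity of that method is O(lgM + lgN). I first thought this misses the problem's requirement because the product MN is larger than the sum M + N, but lgM + lgN <= 2 lg (M + N), so it is still O(lg (M + N)); it is just weaker than the editorial's O(lg (min (M, N))).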

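I later reread dinitz's answer, and it is not quite what I had understood. It says the index of the median is floor ((N + 1) / 2), but on a closer look this is 1-based, since he consistently uses 1-based indexing.

So the editorial above needs to be reread with this in mind: the value floor ((N + 1) / 2) is a length, not an index. This also matches the original definition of i and j: i + j is the length of the left part.


This is the top solution in the discussion; it is actually the post the editorial author originally made there: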

@MissMary said in Share my O(log(min(m,n)) solution with explanation:

To solve this problem, we need to understand "What is the use of median". In statistics, the median is used for dividing a set into two equal length subsets, that one subset is always greater than the other. If we understand the use of median for dividing, we are very close to the answer.

First let's cut A into two parts at a random position i:

      left_A             |        right_A  
A[0], A[1], ..., A[i-1]  |  A[i], A[i+1], ..., A[m-1]  

Since A has m elements, so there are m+1 kinds of cutting( i = 0 ~ m ). And we know: len(left_A) = i, len(right_A) = m - i . Note: when i = 0 , left_A is empty, and when i = m , right_A is empty.

With the same way, cut B into two parts at a random position j:

      left_B             |        right_B  
B[0], B[1], ..., B[j-1]  |  B[j], B[j+1], ..., B[n-1]  

Put left_A and left_B into one set, and put right_A and right_B into another set. Let's name them left_part and right_part :

      left_part          |        right_part  
A[0], A[1], ..., A[i-1]  |  A[i], A[i+1], ..., A[m-1]  
B[0], B[1], ..., B[j-1]  |  B[j], B[j+1], ..., B[n-1]  

If we can ensure:

1) len(left_part) == len(right_part)  
2) max(left_part) <= min(right_part)  

then we divide all elements in {A, B} into two parts with equal length, and one part is always greater than the other. Then median = (max(left_part) + min(right_part))/2.

To ensure these two conditions, we just need to ensure:

(1) i + j == m - i + n - j (or: m - i + n - j + 1)  
    if n >= m, we just need to set: i = 0 ~ m, j = (m + n + 1)/2 - i  
(2) B[j-1] <= A[i] and A[i-1] <= B[j]  

ps.1 For simplicity, I presume A[i-1],B[j-1],A[i],B[j] are always valid even if i=0/i=m/j=0/j=n . I will talk about how to deal with these edge values at last.

ps.2 Why n >= m? Because I have to make sure j is non-negative since 0 <= i <= m and j = (m + n + 1)/2 - i. If n < m, then j may be negative, which will lead to a wrong result.

So, all we need to do is:

Searching i in [0, m], to find an object `i` that:  
    B[j-1] <= A[i] and A[i-1] <= B[j], ( where j = (m + n + 1)/2 - i )  

And we can do a binary search following steps described below:

<1> Set imin = 0, imax = m, then start searching in [imin, imax]  

<2> Set i = (imin + imax)/2, j = (m + n + 1)/2 - i  

<3> Now we have len(left_part)==len(right_part). And there are only 3 situations  
     that we may encounter:  
    <a> B[j-1] <= A[i] and A[i-1] <= B[j]  
        Means we have found the object `i`, so stop searching.  
    <b> B[j-1] > A[i]  
        Means A[i] is too small. We must `adjust` i to get `B[j-1] <= A[i]`.  
        Can we `increase` i?  
            Yes. Because when i is increased, j will be decreased.  
            So B[j-1] is decreased and A[i] is increased, and `B[j-1] <= A[i]` may  
            be satisfied.  
        Can we `decrease` i?  
            `No!` Because when i is decreased, j will be increased.  
            So B[j-1] is increased and A[i] is decreased, and B[j-1] <= A[i] will  
            be never satisfied.  
        So we must `increase` i. That is, we must adjust the searching range to  
        [i+1, imax]. So, set imin = i+1, and goto <2>.  
    <c> A[i-1] > B[j]  
        Means A[i-1] is too big. And we must `decrease` i to get `A[i-1]<=B[j]`.  
        That is, we must adjust the searching range to [imin, i-1].  
        So, set imax = i-1, and goto <2>.  

When the object i is found, the median is:

max(A[i-1], B[j-1]) (when m + n is odd)  
or (max(A[i-1], B[j-1]) + min(A[i], B[j]))/2 (when m + n is even)  

Now let's consider the edges values i=0,i=m,j=0,j=n where A[i-1],B[j-1],A[i],B[j] may not exist. Actually this situation is easier than you think.

What we need to do is ensuring that max(left_part) <= min(right_part). So, if i and j are not edges values(means A[i-1],B[j-1],A[i],B[j] all exist), then we must check both B[j-1] <= A[i] and A[i-1] <= B[j]. But if some of A[i-1],B[j-1],A[i],B[j] don't exist, then we don't need to check one(or both) of these two conditions. For example, if i=0, then A[i-1] doesn't exist, then we don't need to check A[i-1] <= B[j]. So, what we need to do is:

Searching i in [0, m], to find an object `i` that:  
    (j == 0 or i == m or B[j-1] <= A[i]) and  
    (i == 0 or j == n or A[i-1] <= B[j])  
    where j = (m + n + 1)/2 - i  

And in a searching loop, we will encounter only three situations:

<a> (j == 0 or i == m or B[j-1] <= A[i]) and  
    (i == 0 or j == n or A[i-1] <= B[j])  
    Means i is perfect, we can stop searching.  

<b> j > 0 and i < m and B[j - 1] > A[i]  
    Means i is too small, we must increase it.  

<c> i > 0 and j < n and A[i - 1] > B[j]  
    Means i is too big, we must decrease it.  

Thanks to @Quentin.chen, who pointed out that: i < m ==> j > 0 and i > 0 ==> j < n. Because:

m <= n, i < m ==> j = (m+n+1)/2 - i > (m+n+1)/2 - m >= (2*m+1)/2 - m >= 0      
m <= n, i > 0 ==> j = (m+n+1)/2 - i < (m+n+1)/2 <= (2*n+1)/2 <= n  

So in situations <b> and <c>, we don't need to check whether j > 0 and whether j < n.
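The two implications can also be sanity-checked by brute force over small lengths. This is only a sketch to build confidence in the inequalities above, not a proof; the name holdsUpTo and the length bound are my own choices.

```java
class BoundaryImplicationCheck {
    // For every 0 <= m <= n <= maxLen and every cut 0 <= i <= m, verify that
    // j = (m + n + 1) / 2 - i stays in [0, n], that i < m implies j > 0,
    // and that i > 0 implies j < n.
    static boolean holdsUpTo(int maxLen) {
        for (int n = 0; n <= maxLen; n++) {
            for (int m = 0; m <= n; m++) {
                int half = (m + n + 1) / 2;
                for (int i = 0; i <= m; i++) {
                    int j = half - i;
                    if (j < 0 || j > n) return false;
                    if (i < m && j <= 0) return false;   // would make B[j-1] invalid in <b>
                    if (i > 0 && j >= n) return false;   // would make B[j] invalid in <c>
                }
            }
        }
        return true;
    }
}
```

Calling BoundaryImplicationCheck.holdsUpTo(50) exercises every length pair up to 50; note that the check depends on m <= n, which is why the code swaps the arrays first.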

Below is the accepted code:

     def median(A, B):  
        m, n = len(A), len(B)  
        if m > n:  
            A, B, m, n = B, A, n, m  
        if n == 0:  
            raise ValueError  

        imin, imax, half_len = 0, m, (m + n + 1) // 2  # floor division, also valid in Python 3
        while imin <= imax:  
            i = (imin + imax) // 2  
            j = half_len - i  
            if i < m and B[j-1] > A[i]:  
                # i is too small, must increase it  
                imin = i + 1  
            elif i > 0 and A[i-1] > B[j]:  
                # i is too big, must decrease it  
                imax = i - 1  
            else:  
                # i is perfect  

                if i == 0: max_of_left = B[j-1]  
                elif j == 0: max_of_left = A[i-1]  
                else: max_of_left = max(A[i-1], B[j-1])  

                if (m + n) % 2 == 1:  
                    return max_of_left  

                if i == m: min_of_right = B[j]  
                elif j == n: min_of_right = A[i]  
                else: min_of_right = min(A[i], B[j])  

                return (max_of_left + min_of_right) / 2.0

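The discussion version, at least, has no typo.

Looking further, the post was made in 2014 and the typo was only fixed in August 2017; someone who asked about it in 2015 got jumped on by three people.

There is also an interesting reply under this post: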

@yro said in Share my O(log(min(m,n)) solution with explanation:

@MissMary
Thank you so much for this detailed explanation and concise code in python;

I think the post-processing after i, j are found can be simplified by adding Integer.MIN_VALUE and Integer.MAX_VALUE on either side of the boundaries.

               A_Left         |        A_Right  
-oo  A[0], A[1], ... , A[i-1] | A[i], A[i+1], ... , A[m-1], A[m]  +oo  
               B_Left         |        B_Right  
-oo  B[0], B[1], ... , B[j-1] | B[j], B[j+1], ... , B[n-1], B[n]  +oo

This modification
(1) solves the problem of invalid value in the edge cases: A[- 1] = Integer.MIN_VALUE or A[m] = Integer.MAX_VALUE.
(2) does not affect the position of median, because Integer.MIN_VALUE on the left is paired with Integer.MAX_VALUE on the right.
(3) does not affect the computation of max_left and min_right: suppose i == 0, which means that A_Left is empty. Considering the left part (including A_Left and B_Left) should not be empty, B_Left is not empty. This implies that max_left is finite value. It is the same case as min_right.

Here is my code:

 public double findMedianSortedArrays(int[] A, int[] B) {  
        int m = A.length;  
        int n = B.length;  
        // make sure m <= n  
        if (m > n) return findMedianSortedArrays(B, A);  

        int imin = 0, imax = m;  
        while (imin <= imax) {  
            int i = imin + (imax - imin) / 2;  
            int j = (m + n + 1) / 2 - i;  

            int A_left = i == 0 ? Integer.MIN_VALUE : A[i - 1];  
            int A_right = i == m ? Integer.MAX_VALUE : A[i];  
            int B_left = j == 0 ? Integer.MIN_VALUE : B[j - 1];  
            int B_right = j == n ? Integer.MAX_VALUE : B[j];  

            if (A_left > B_right) {  
                imax = i - 1;  
            } else if (B_left > A_right) {  
                imin = i + 1;  
            } else {  
                int max_left = A_left > B_left ? A_left : B_left;  
                int min_right = A_right > B_right ? B_right : A_right;  
                if ((m + n) % 2 == 1)   
                    return max_left; // # of left_part = # of right_part + 1;  
                else   
                    return (max_left + min_right) / 2.0;  
            }  
        }  
        return -1;  
    }

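This is an interesting idea: using sentinels to handle corner cases. I have seen it before, for example in merge-sorted problems.

Note how the sentinel is implemented: it is not actually appended to the input; instead, as in a problem I did just before this one, it is a conditional dummy written inside the loop. What concrete advantage this has over appending directly, I am not yet sure.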
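One plausible advantage, offered here as my own reading rather than anything the poster states: Java arrays cannot grow in place, so appending real sentinels means allocating and copying both inputs, which costs O(m + n) and would dominate the O(log(min(m, n))) search. A conditional lookup avoids the copy and leaves the inputs untouched. The helper names below are hypothetical.

```java
class SentinelLookup {
    // Value just left of a cut placed after `cut` elements; -infinity stand-in when nothing is there.
    static int leftOf(int[] arr, int cut) {
        return cut > 0 ? arr[cut - 1] : Integer.MIN_VALUE;
    }

    // Value just right of the same cut; +infinity stand-in when the right side is empty.
    static int rightOf(int[] arr, int cut) {
        return cut < arr.length ? arr[cut] : Integer.MAX_VALUE;
    }
}
```

With these two helpers the four assignments inside the loop become leftOf(A, i), rightOf(A, i), leftOf(B, j) and rightOf(B, j), each O(1).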

Also, his final computation is slightly worse than the original author's: he eagerly computes both max and min, whereas the original author computes min only once it is known to be needed.


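Judging from the discussion, this algorithm has always been considered hard to implement precisely because the corner cases are hard to handle.

Here is another very good solution from the discussion: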

@stellari said in Very concise O(log(min(M,N))) iterative solution with detailed explanation:

This problem is notoriously hard to implement due to all the corner cases. Most implementations consider odd-length and even-length arrays as two different cases and treat them separately. As a matter of fact, with a little mind twist, these two cases can be combined into one, leading to a very simple solution where (almost) no special treatment is needed.

First, let's see the concept of 'MEDIAN' in a slightly unconventional way. That is:

"if we cut the sorted array to two halves of EQUAL LENGTHS, then
median is the AVERAGE OF Max(lower_half) and Min(upper_half), i.e. the
two numbers immediately next to the cut".

For example, for [2 3 5 7], we make the cut between 3 and 5:

[2 3 / 5 7]  

then the median = (3+5)/2. Note that I'll use '/' to represent a cut, and (number / number) to represent a cut made through a number in this article.

for [2 3 4 5 6], we make the cut right through 4 like this:

[2 3 (4/4) 5 6]

Since we split 4 into two halves, we say now both the lower and upper subarray contain 4. This notion also leads to the correct answer: (4 + 4) / 2 = 4;

For convenience, let's use L to represent the number immediately left to the cut, and R the right counterpart. In [2 3 5 7], for instance, we have L = 3 and R = 5, respectively.

We observe the index of L and R have the following relationship with the length of the array N:

N        Index of L / R  
1               0 / 0  
2               0 / 1  
3               1 / 1    
4               1 / 2        
5               2 / 2  
6               2 / 3  
7               3 / 3  
8               3 / 4  

It is not hard to conclude that index of L = (N-1)/2, and R is at N/2. Thus, the median can be represented as

(L + R)/2 = (A[(N-1)/2] + A[N/2])/2  

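Note his line of thinking, which is very simple: first know what you want to know (what is the index of the median), then work through examples and observe the pattern. This is the way of thinking and problem solving that a CS person most needs to develop.

Also, dinitz's answer uses these two notions: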

med_low = floor ((n + 1) / 2)  
med_high = ceil ((n + 1) / 2) = floor ((n + 2) / 2)

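This is really the same as the conclusion above: everything dinitz writes uses 1-based indices, while the OP here, like most LeetCode conventions, uses 0-based ones.

The formula below it also has an exact counterpart in dinitz's solution:

medA = ((A[med_low] + A[med_high]) / 2)

This is still 1-based, of course, but it corresponds to his conclusion above.

On second thought, dinitz's formulas do not give the correct 0-based values directly; they are off by exactly 1. That seems to be just because the two definitions use different index bases; anyway:

So for this problem, remember two ways of computing the median:

  • One, as in the editorial, works purely with lengths: the left half has length (N + 1) / 2, and you then compare the extreme values of the two halves;
  • The other works purely with indices: both this solution and dinitz's use this idea. The concrete formulas come from spotting a pattern, but you need to know how to compute med_low and med_high, especially in the 0-based case.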
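The index-based view can be sketched as follows, with both index bases side by side. The formulas come from the table and from the med_low / med_high definitions quoted above; the class and method names are made up for illustration, and the input is assumed to be non-empty and sorted.

```java
class MedianIndices {
    // 1-based positions, following the med_low / med_high definitions quoted above.
    static int medLowOneBased(int n)  { return (n + 1) / 2; }
    static int medHighOneBased(int n) { return (n + 2) / 2; }

    // 0-based indices, following the L / R table from the post.
    static int leftIndex(int n)  { return (n - 1) / 2; }
    static int rightIndex(int n) { return n / 2; }

    // Median of a non-empty sorted array using the 0-based indices.
    static double median(int[] sorted) {
        int n = sorted.length;
        return ((long) sorted[leftIndex(n)] + sorted[rightIndex(n)]) / 2.0;
    }
}
```

For every n >= 1 the 1-based values are exactly one larger than the 0-based ones, which is the off-by-one noted above.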

To get ready for the two array situation, let's add a few imaginary 'positions' (represented as #'s) in between numbers, and treat numbers as 'positions' as well.

[6 9 13 18]  ->   [# 6 # 9 # 13 # 18 #]    (N = 4)  
position index     0 1 2 3 4 5  6 7  8     (N_Position = 9)  

[6 9 11 13 18]->   [# 6 # 9 # 11 # 13 # 18 #]   (N = 5)  
position index      0 1 2 3 4 5  6 7  8 9 10    (N_Position = 11)  

As you can see, there are always exactly 2N+1 'positions' regardless of length N. Therefore, the middle cut should always be made on the Nth position (0-based). Since index(L) = (N-1)/2 and index(R) = N/2 in this situation, we can infer that index(L) = (CutPosition-1)/2, index(R) = (CutPosition)/2.

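This idea is quite interesting: it forces both odd and even lengths into an odd number of positions, which makes everything afterwards easy to handle. Also, 2 * N + 1 should be obvious: N + N + 1, always odd.

The final result (index(L) = (CutPosition-1)/2, index(R) = CutPosition/2, bold in the original post) follows directly from the formulas for L and R above.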
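A small sketch of the position trick for one array, generalised to an arbitrary cut position in [0, 2N]. The out-of-range handling with Integer.MIN_VALUE / Integer.MAX_VALUE anticipates the sentinel note later in the same post; the names leftOfCut and rightOfCut are hypothetical.

```java
class CutPositions {
    // Number immediately left of cut position `cut` among the 2N+1 positions.
    static int leftOfCut(int[] a, int cut) {
        return cut == 0 ? Integer.MIN_VALUE : a[(cut - 1) / 2];
    }

    // Number immediately right of cut position `cut`.
    static int rightOfCut(int[] a, int cut) {
        return cut == 2 * a.length ? Integer.MAX_VALUE : a[cut / 2];
    }

    // Median of a non-empty sorted array: cut at the middle position N.
    static double median(int[] sorted) {
        int cut = sorted.length;
        return ((long) leftOfCut(sorted, cut) + rightOfCut(sorted, cut)) / 2.0;
    }
}
```

A cut on an odd position lands on a number, so L and R are the same element; a cut on an even position lands on a '#', so L and R are neighbours.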


Now for the two-array case:

A1: [# 1 # 2 # 3 # 4 # 5 #]    (N1 = 5, N1_positions = 11)  

A2: [# 1 # 1 # 1 # 1 #]     (N2 = 4, N2_positions = 9)  

Similar to the one-array problem, we need to find a cut that divides the two arrays each into two halves such that

"any number in the two left halves" <= "any number in the two right
halves".

We can also make the following observations:

  1. There are 2N1 + 2N2 + 2 positions altogether. Therefore, there must be exactly N1 + N2 positions on each side of the cut, and 2 positions directly on the cut.

  2. Therefore, when we cut at position C2 = K in A2, then the cut position in A1 must be C1 = N1 + N2 - K. For instance, if C2 = 2, then we must have C1 = 4 + 5 - C2 = 7.

     [# 1 # 2 # 3 # (4/4) # 5 #]      
    
     [# 1 / 1 # 1 # 1 #]     
    
  3. When the cuts are made, we'd have two L's and two R's. They are

     L1 = A1[(C1-1)/2]; R1 = A1[C1/2];  
     L2 = A2[(C2-1)/2]; R2 = A2[C2/2];  
    

In the above example,

    L1 = A1[(7-1)/2] = A1[3] = 4; R1 = A1[7/2] = A1[3] = 4;  
    L2 = A2[(2-1)/2] = A2[0] = 1; R2 = A2[2/2] = A2[1] = 1;  

Now how do we decide if this cut is the cut we want? Because L1, L2 are the greatest numbers on the left halves and R1, R2 are the smallest numbers on the right, we only need

L1 <= R1 && L1 <= R2 && L2 <= R1 && L2 <= R2  

to make sure that any number in lower halves <= any number in upper halves. As a matter of fact, since
L1 <= R1 and L2 <= R2 are naturally guaranteed because A1 and A2 are sorted, we only need to make sure:

L1 <= R2 and L2 <= R1.

Now we can use simple binary search to find out the result.

If we have L1 > R2, it means there are too many large numbers on the left half of A1, then we must move C1 to the left (i.e. move C2 to the right);   
If L2 > R1, then there are too many large numbers on the left half of A2, and we must move C2 to the left.  
Otherwise, this cut is the right one.   
After we find the cut, the median can be computed as (max(L1, L2) + min(R1, R2)) / 2;  

This author seems to be an expert whom even Stefan respects, and the solution is full of tricks, but overall it is still too complicated.

Two side notes:

A. Since C1 and C2 can be mutually determined from each other, we can just move one of them first, then calculate the other accordingly. However, it is much more practical to move C2 (the one on the shorter array) first. The reason is that on the shorter array, all positions are possible cut locations for median, but on the longer array, the positions that are too far left or right are simply impossible for a legitimate cut. For instance, [1], [2 3 4 5 6 7 8]. Clearly the cut between 2 and 3 is impossible, because the shorter array does not have that many elements to balance out the [3 4 5 6 7 8] part if you make the cut this way. Therefore, for the longer array to be used as the basis for the first cut, a range check must be performed. It would be just easier to do it on the shorter array, which requires no checks whatsoever. Also, moving only on the shorter array gives a run-time complexity of O(log(min(N1, N2))) (edited as suggested by @baselRus)

B. The only edge case is when a cut falls on the 0th(first) or the 2Nth(last) position. For instance, if C2 = 2N2, then R2 = A2[2*N2/2] = A2[N2], which exceeds the boundary of the array. To solve this problem, we can imagine that both A1 and A2 actually have two extra elements, INT_MIN at A[-1] and INT_MAX at A[N]. These additions don't change the result, but make the implementation easier: If any L falls out of the left boundary of the array, then L = INT_MIN, and if any R falls out of the right boundary, then R = INT_MAX.

This is the sentinel trick already mentioned in the other post.


I know that was not very easy to understand, but all the above reasoning eventually boils down to the following concise code:

     double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {  
        int N1 = nums1.size();  
        int N2 = nums2.size();  
        if (N1 < N2) return findMedianSortedArrays(nums2, nums1); // Make sure A2 is the shorter one.  

        int lo = 0, hi = N2 * 2;  
        while (lo <= hi) {  
            int mid2 = (lo + hi) / 2;   // Try Cut 2   
            int mid1 = N1 + N2 - mid2;  // Calculate Cut 1 accordingly  

            double L1 = (mid1 == 0) ? INT_MIN : nums1[(mid1-1)/2];    // Get L1, R1, L2, R2 respectively  
            double L2 = (mid2 == 0) ? INT_MIN : nums2[(mid2-1)/2];  
            double R1 = (mid1 == N1 * 2) ? INT_MAX : nums1[(mid1)/2];  
            double R2 = (mid2 == N2 * 2) ? INT_MAX : nums2[(mid2)/2];  

            if (L1 > R2) lo = mid2 + 1;       // A1's lower half is too big; need to move C1 left (C2 right)  
            else if (L2 > R1) hi = mid2 - 1;  // A2's lower half too big; need to move C2 left.  
            else return (max(L1,L2) + min(R1, R2)) / 2;   // Otherwise, that's the right cut.  
        }  
        return -1;  
    }

If you have any suggestions to make the logic and implementation even cleaner, please do let me know!

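This person is really good. Note how he uses recursion to perform the swap when the lengths are in the wrong order; I thought of the same thing while reading the editorial.

The resulting code is actually very similar to the editorial's solution.

But there still seem to be differences, for example: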

@msai06 said in Very concise O(log(min(M,N))) iterative solution with detailed explanation:

can anyone explain me........why we need to assign hi=n2*2

This is because the cut position he searches for is defined on his imaginary array of positions.

@jdneo said in Very concise O(log(min(M,N))) iterative solution with detailed explanation:

@msai06 because the array is 0-based, since the array has 2N+1 positions, the upper bound is 2N.

Also note that when he computes mid1 from mid2 there is no +1 anymore; again, the computation is set up differently.
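To see both differences (hi = 2 * N2 and no +1) in the language used in the rest of these notes, here is a rough Java port of the C++ above. It is my own transcription, not the poster's code: it uses long values with Long.MIN_VALUE / Long.MAX_VALUE as sentinels instead of double, and it adds an explicit check for two empty inputs.

```java
class VirtualCutMedian {
    static double median(int[] nums1, int[] nums2) {
        int n1 = nums1.length, n2 = nums2.length;
        if (n1 < n2) return median(nums2, nums1);        // make nums2 the shorter array
        if (n1 == 0) throw new IllegalArgumentException("both arrays are empty");
        int lo = 0, hi = 2 * n2;                          // cut positions 0..2*n2 on nums2
        while (lo <= hi) {
            int c2 = lo + (hi - lo) / 2;                  // try a cut on the shorter array
            int c1 = n1 + n2 - c2;                        // cut on nums1 follows; no +1 needed
            long l1 = c1 == 0 ? Long.MIN_VALUE : nums1[(c1 - 1) / 2];
            long l2 = c2 == 0 ? Long.MIN_VALUE : nums2[(c2 - 1) / 2];
            long r1 = c1 == 2 * n1 ? Long.MAX_VALUE : nums1[c1 / 2];
            long r2 = c2 == 2 * n2 ? Long.MAX_VALUE : nums2[c2 / 2];
            if (l1 > r2) {
                lo = c2 + 1;                              // move C1 left, i.e. C2 right
            } else if (l2 > r1) {
                hi = c2 - 1;                              // move C2 left
            } else {
                return (Math.max(l1, l2) + Math.min(r1, r2)) / 2.0;
            }
        }
        throw new IllegalStateException("inputs must be sorted");
    }
}
```

Because every cut leaves N1 + N2 positions on each side, odd and even totals go through the same return statement; in the odd case one cut lands on a number and that number is counted as both L and R.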


Another solution from the discussion, which reduces the problem to finding the kth element:

@vaputa said in Share my simple O(log(m+n)) solution for your reference:

Binary search. Call 2 times getkth and k is about half of (m + n). Every time call getkth can reduce the scale k to its half. So the time complexity is log(m + n).

    class Solution {  
    public:  
        int getkth(int s[], int m, int l[], int n, int k){  
            // let m <= n  
            if (m > n)   
                return getkth(l, n, s, m, k);  
            if (m == 0)  
                return l[k - 1];  
            if (k == 1)  
                return min(s[0], l[0]);  

            int i = min(m, k / 2), j = min(n, k / 2);  
            if (s[i - 1] > l[j - 1])  
                return getkth(s, m, l + j, n - j, k - j);  
            else  
                return getkth(s + i, m - i, l, n, k - i);  
            return 0;  
        }  

        double findMedianSortedArrays(int A[], int m, int B[], int n) {  
            int l = (m + n + 1) >> 1;  
            int r = (m + n + 2) >> 1;  
            return (getkth(A, m ,B, n, l) + getkth(A, m, B, n, r)) / 2.0;  
        }  
    };

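First, note that the fifth argument here is really a 1-based index; look at the k == 1 base case of getkth if you doubt it. So the computation in the main function is exactly the dinitz method discussed above: med = ((A[med_low] + A[med_high]) / 2.0). Note that this division by 2 is not an integer division! This is also why floor and ceil appear so often in dinitz's algorithms: the code he writes assumes ordinary division rather than integer division.

At first I thought his m and n were 0-based end indices rather than lengths (an impression I got from the Java version below), but no, they are lengths.

This algorithm is arguably much easier to understand, provided you are familiar with the kth-element algorithm.

Note that what he passes in the recursive calls, such as l + j, is the standard start_index + offset addressing pattern. When it is unclear what m and n mean, look at how they are used, which here means looking at how i and j are used, for example the line if (s[i - 1] > l[j - 1]).

I did not fully understand the essence of this kth algorithm at first; after working an example by hand I more or less got the idea. One small question: why must the recursive function enforce the order of the two arrays? There seems to be no later step that compares m and n directly, so I cannot quite figure this out. I could not submit it myself to check either, because LeetCode seems to have changed the interface of this problem.

On the median computation: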

@oneone said in Share my simple O(log(m+n)) solution for your reference:

could you explain int l = (m + n + 1) >> 1; int r = (m + n + 2) >> 1; ?

The author's own explanation:

@vaputa said in Share my simple O(log(m+n)) solution for your reference:

get the medians. l and r are the medians.
when n + m is even, l + 1 == r
when n + m is odd, l == r
we can merge 2 cases

Someone below points out that this can be optimized:

@doublefish said in Share my simple O(log(m+n)) solution for your reference:

so when l == r, you will run it twice.

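The author agrees the duplicate call can be optimized away. Keep this in mind and do not get stuck in the pattern of dinitz's approach; this kind of optimization could well come up in an interview.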
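A minimal sketch of that optimization, assuming index-based recursion instead of C pointers: the second kth call is skipped when l == r (odd total length). The class name KthMedian and the exact parameter layout are my own; the elimination logic follows the C++ getkth above, and the case of two empty inputs is not handled.

```java
class KthMedian {
    static double median(int[] a, int[] b) {
        int total = a.length + b.length;
        int left = (total + 1) / 2, right = (total + 2) / 2;   // 1-based ranks of the middle elements
        int lo = kth(a, 0, b, 0, left);
        if (left == right) return lo;                            // odd total: one call is enough
        return ((long) lo + kth(a, 0, b, 0, right)) / 2.0;
    }

    // k-th smallest (1-based) among a[sa..] and b[sb..].
    static int kth(int[] a, int sa, int[] b, int sb, int k) {
        int lenA = a.length - sa, lenB = b.length - sb;
        if (lenA > lenB) return kth(b, sb, a, sa, k);            // keep the shorter range first
        if (lenA == 0) return b[sb + k - 1];
        if (k == 1) return Math.min(a[sa], b[sb]);
        int i = Math.min(lenA, k / 2), j = Math.min(lenB, k / 2);
        if (a[sa + i - 1] > b[sb + j - 1]) {
            return kth(a, sa, b, sb + j, k - j);                 // drop the j smallest of b
        }
        return kth(a, sa + i, b, sb, k - i);                     // drop the i smallest of a
    }
}
```

Keeping the shorter range first also means a single emptiness check is enough: if the shorter range is non-empty, so is the longer one.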

A Java version:

@mitulshr said in Share my simple O(log(m+n)) solution for your reference:

Anyone waiting for Java version??

    public class Solution {  
        public double findMedianSortedArrays(int[] nums1, int[] nums2) {  
            //Get the middle element  
            int mid = (nums2.length+nums1.length+1)/2;  
            //Find that middle element  
            double res = getkth(nums1, nums2, mid);  
            //If the combined length is even then find mid+1 element as well  
            if((nums2.length+nums1.length) % 2 == 0) {  
                res += getkth(nums1, nums2, mid+1);  
                //Find the average of two elements  
                res = res/2;  
            }  
            return res;  
        }  
        public int getkth(int[] A, int[] B, int k) {  
            //Make sure array A is the smaller array  
            if(B.length < A.length ) return getkth(B, A, k);  
            //If smaller array is empty, simply return the value from second array  
            if(A.length == 0) return B[k-1];  
            //If k is 1, then it must be the smaller of first element of the array  
            if(k == 1) return Math.min(A[0], B[0]);  

            //Get the index for array A to compare  
            int i = Math.min((A.length), k/2);  
            //Index for array B must be such that i + j = k  
            int j = k - i;  

            //Remove the smaller elements from the array A if, ith index of A is smaller than jth index of B  
            if(A[i- 1] <  B[j-1]) {  
                int[] newA = new int[A.length - i];  
                //Make a new array and copy the rest of the array elements  
                System.arraycopy(A, i, newA, 0, (A.length - i));  
                return getkth(newA, B, k - i);  
            }  
            else {  
                int[] newB = new int[B.length - j];  
                System.arraycopy(B, j, newB, 0, (B.length - j));  
                return getkth(A, newB, k - j);  
            }  
        }  
    }

It uses a small trick when computing j, somewhat similar to this suggestion:

@qinlei515 said in Share my simple O(log(m+n)) solution for your reference:

I think j = min(n, k / 2) is not necessary. Because if k/2 > n then we must have k/2 > m also, since m <= n.

Then in this case we cannot have the expected value since m+n < k. According to this algorithm it should not happen.

And I change it to just j = k/2 the solution also accepted.

But this version has a serious problem:

@sculd said in Share my simple O(log(m+n)) solution for your reference:

Copying an array each iteration takes O(n) time. Keeping head position for l and s would be nicer.

When porting C code, one thing needs attention: the array-name parameter in the OP's kth function is really the start index/address of an array. Someone later noticed this and wrote a better Java version:

@jimmyzzxhlh said in Share my simple O(log(m+n)) solution for your reference:

Another java version without array copy. Should be easy to understand. Thanks for sharing the solution!

    public double findMedianSortedArrays(int[] nums1, int[] nums2) {  
      int n = nums1.length;  
      int m = nums2.length;  
      int left = (n + m + 1) / 2;  
      int right = (n + m + 2) / 2;  
      return (getKth(nums1, 0, n - 1, nums2, 0, m - 1, left) + getKth(nums1, 0, n - 1, nums2, 0, m - 1, right)) * 0.5;    
    }  

    private int getKth(int[] nums1, int start1, int end1, int[] nums2, int start2, int end2, int k) {  
      int len1 = end1 - start1 + 1;  
      int len2 = end2 - start2 + 1;  
      if (len1 > len2) return getKth(nums2, start2, end2, nums1, start1, end1, k);  
      if (len1 == 0) return nums2[start2 + k - 1];  
      if (k == 1) return Integer.min(nums1[start1], nums2[start2]);  

      int i = start1 + Integer.min(len1, k / 2) - 1;  
      int j = start2 + Integer.min(len2, k / 2) - 1;  
      //Eliminate half of the elements from one of the smaller arrays  
      if (nums1[i] > nums2[j]) {  
          return getKth(nums1, start1, end1, nums2, j + 1, end2, k - (j - start2 + 1));  
      }  
      else {  
          return getKth(nums1, i + 1, end1, nums2, start2, end2, k - (i - start1 + 1));  
      }  
    }

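This is a common way to simulate pointers in Java: for pointer operations on arrays, maintain int indices directly.

One odd thing about this rewrite is that i and j are computed differently from the original author's version. Here is my own rewrite: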

class Solution {  
  public double findMedianSortedArrays(int[] nums1, int[] nums2) {  
    int n = nums1.length;  
    int m = nums2.length;  
    int left = (n + m + 1) / 2;  
    int right = (n + m + 2) / 2;  
    return (getKth(nums1, 0, n - 1, nums2, 0, m - 1, left) + getKth(nums1, 0, n - 1, nums2, 0, m - 1, right)) * 0.5;    
  }  

  private int getKth(int[] nums1, int start1, int end1, int[] nums2, int start2, int end2, int k) {  
    int len1 = end1 - start1 + 1;  
    int len2 = end2 - start2 + 1;  
    if (len1 > len2) return getKth(nums2, start2, end2, nums1, start1, end1, k);  
    if (len1 == 0) return nums2[start2 + k - 1];  
    if (k == 1) return Integer.min(nums1[start1], nums2[start2]);  

    int i = Integer.min(len1, k / 2), j = Integer.min(len2, k / 2);  
    if (nums1[start1 + i - 1] > nums2[start2 + j - 1])  
      return getKth(nums1, start1, end1, nums2, start2 + j, end2, k - j);  
    else  
      return getKth(nums1, start1 + i, end1, nums2, start2, end2, k - i);  
  }  
}

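This matches the original author's computation more closely. The main thing to understand is how the index variables (start1, start2) and the length variables (i, j, k) interact.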
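To make the index/length relationship concrete, here is a minimal iterative sketch of the same k-th element idea. The class name `KthSketch`, the method name `kth`, and the variables `s1`, `s2`, `take1`, `take2` are made up for illustration; this is not the original author's code. `s1`/`s2` are 0-based start indices, while `take1`/`take2` and `k` are counts.

```java
// Hypothetical sketch: iterative k-th smallest (k is 1-based) of two sorted arrays.
// Assumes 1 <= k <= a.length + b.length.
class KthSketch {
    static int kth(int[] a, int[] b, int k) {
        int s1 = 0, s2 = 0;                     // 0-based starts of the live windows
        while (true) {
            int len1 = a.length - s1;           // how many elements are still live
            int len2 = b.length - s2;
            if (len1 == 0) return b[s2 + k - 1];    // count k -> index offset k - 1
            if (len2 == 0) return a[s1 + k - 1];
            if (k == 1) return Math.min(a[s1], b[s2]);

            int take1 = Math.min(len1, k / 2);  // counts, not indices
            int take2 = Math.min(len2, k / 2);
            // A block of `take` items starting at index s ends at index s + take - 1.
            if (a[s1 + take1 - 1] > b[s2 + take2 - 1]) {
                s2 += take2;                    // discard take2 items from b
                k -= take2;
            } else {
                s1 += take1;                    // discard take1 items from a
                k -= take1;
            }
        }
    }
}
```

Advancing a start index by a count (`s2 += take2`) and shrinking `k` by the same count is exactly the `start2 + j` / `k - j` pair in the recursive version.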


Another solution from the discussion:

@tyuan73 said in Share my iterative solution with O(log(min(n, m))):

This is my iterative solution using binary search. The main idea is to find the approximate location of the median and compare the elements around it to get the final result.

  1. do binary search. suppose the shorter list is A with length n. the runtime is O(log(n)) which means no matter how large B array is, it only depends on the size of A. It makes sense because if A has only one element while B has 100 elements, the median must be one of A[0], B[49], and B[50] without check everything else. If A[0] <= B[49], B[49] is the answer; if B[49] < A[0] <= B[50], A[0] is the answer; else, B[50] is the answer.

  2. After binary search, we get the approximate location of median. Now we just need to compare at most 4 elements to find the answer. This step is O(1).

  3. the same solution can be applied to find kth element of 2 sorted arrays.

Here is the code:

        public double findMedianSortedArrays(int A[], int B[]) {  
        int n = A.length;  
        int m = B.length;  
        // the following call is to make sure len(A) <= len(B).  
        // yes, it calls itself, but at most once, shouldn't be  
        // consider a recursive solution  
        if (n > m)  
            return findMedianSortedArrays(B, A);  

        // now, do binary search  
        int k = (n + m - 1) / 2;  
        int l = 0, r = Math.min(k, n); // r is n, NOT n-1, this is important!!  
        while (l < r) {  
            int midA = (l + r) / 2;  
            int midB = k - midA;  
            if (A[midA] < B[midB])  
                l = midA + 1;  
            else  
                r = midA;  
        }  

        // after binary search, we almost get the median because it must be between  
        // these 4 numbers: A[l-1], A[l], B[k-l], and B[k-l+1]   

        // if (n+m) is odd, the median is the larger one between A[l-1] and B[k-l].  
        // and there are some corner cases we need to take care of.  
        int a = Math.max(l > 0 ? A[l - 1] : Integer.MIN_VALUE, k - l >= 0 ? B[k - l] : Integer.MIN_VALUE);  
        if (((n + m) & 1) == 1)  
            return (double) a;  

        // if (n+m) is even, the median can be calculated by   
        //      median = (max(A[l-1], B[k-l]) + min(A[l], B[k-l+1])) / 2.0  
        // also, there are some corner cases to take care of.  
        int b = Math.min(l < n ? A[l] : Integer.MAX_VALUE, k - l + 1 < m ? B[k - l + 1] : Integer.MAX_VALUE);  
        return (a + b) / 2.0;  
    }

I'm lazy to type. But I found a very good pdf to explain my algorithm: http://ocw.alfaisal.edu/NR/rdonlyres/Electrical-Engineering-and-Computer-Science/6-046JFall-2005/30C68118-E436-4FE3-8C79-6BAFBB07D935/0/ps9sol.pdf

BTW: Thanks to xdxiaoxin. I've removed the check "midB > k".

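Note how the bases of k and r relate here. Strictly speaking, both are 0-based. r is an exclusive endpoint, because the binary search below is written in the half-open interval style. As for how k is computed, look at how midB is derived from it below; it is... a bit odd.

I ran into the same question when experimenting on my own: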

@shen5630 said in Share my iterative solution with O(log(min(n, m))):

There is no need to define r = Math.min(k, n), because n is always less than or equal to k.
r=n is enough.

With that change the code still gets AC. Note, though, that the change only affects the search space; k itself is still needed.
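As a sanity check on the r = n simplification, the sketch below restates the quoted solution with that change (plus an overflow-safe final average). The class name `IndexCutSketch` and the method name `median` are invented here, and the comments are one possible reading of the variables, not something tyuan73 states. It assumes at least one of the arrays is non-empty.

```java
// Hypothetical restatement of the quoted solution with r = n.
// Assumes A.length + B.length >= 1.
class IndexCutSketch {
    static double median(int[] A, int[] B) {
        int n = A.length, m = B.length;
        if (n > m) return median(B, A);     // make A the shorter array
        int k = (n + m - 1) / 2;            // the left part will hold k + 1 elements
        int l = 0, r = n;                   // r is exclusive; was Math.min(k, n)
        while (l < r) {
            int midA = l + (r - l) / 2;
            int midB = k - midA;            // always within [0, m - 1] when n <= m
            if (A[midA] < B[midB]) l = midA + 1;
            else r = midA;
        }
        // Cut: A[0..l-1] and B[0..k-l] on the left; A[l..] and B[k-l+1..] on the right.
        int a = Math.max(l > 0 ? A[l - 1] : Integer.MIN_VALUE,
                         k - l >= 0 ? B[k - l] : Integer.MIN_VALUE);
        if (((n + m) & 1) == 1) return a;
        int b = Math.min(l < n ? A[l] : Integer.MAX_VALUE,
                         k - l + 1 < m ? B[k - l + 1] : Integer.MAX_VALUE);
        return ((double) a + b) / 2.0;      // widen before adding to avoid int overflow
    }
}
```

One detail worth noticing: when n == m, k = n - 1, so n <= k does not actually hold. The variant still works because l can reach n only together with the `k - l >= 0` guard, which is the case exercised by `[1, 2]` and `[3, 4]`.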

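With this change, the algorithm is very close to the editorial solution; only some details differ.

If the details are unclear: l and r here are both 0-based indices. r looks 1-based only because it is an exclusive index. k, on the other hand, is closer to a length concept (it is one less than the size of the left part). In the editorial's terms, you can read it like this: look at the four candidates he gives at the end. The last elements of the two left parts are [l - 1] and [k - l], so: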

l - 1 = i - 1  
k - l = j - 1  
k - 1 = i + j - 2  
k = i + j - 1 = (m + n + 1) / 2 - 1 = (m + n - 1) / 2

So the final result is the same as the editorial's.

That settles the equivalence, but what do these variables actually stand for?

@Cat818 said in Share my iterative solution with O(log(min(n, m))):

r = Math.min(k, n); actually r can be directly set to n since n is always smaller than m. Hopefully r = n would help understand better. Then the while loop helps locate the mid position of A.

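Is he right? What we end up with is k + 1 = len(left_part), so what does k itself represent? In his code midA + midB = k, which I can't make sense of; what is that supposed to mean? Code written with the find-index mindset (rather than length, as in the editorial) feels hard to follow.

For example, when his algorithm finishes, l == r == midA, but the place where it actually stops is: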

[l - 1] | [l]  
[k - l] | [k - l + 1]

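So the midA he defines sits to the right of cut_A, while the midB he defines sits to the left of cut_B. That asymmetry looks odd.

My feeling is that he never sorted out the logic up front, wrote something loosely, and then debugged until the math came out right. Personally I find the underlying idea hard to justify.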


A submission that is much faster than mine:

class Solution {  
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {  
        int n1=nums1.length;  
        int n2=nums2.length;  
        if(n1>n2) {  
            return findMedianSortedArrays(nums2,nums1);  
        }  
        if(n1 == 0)  
            return (n2%2==0)?(double)(nums2[n2/2]+nums2[(n2/2)-1])/2 : nums2[n2/2];  
        int l=0,r=n1;  

        while(l<=r) {  
            int partX = (l+r)/2,  
            partY=((n1+n2+1)/2)-partX;  

            int leftX = (partX==0)?Integer.MIN_VALUE:nums1[partX-1];  
            int rightX = (partX==n1)?Integer.MAX_VALUE:nums1[partX];  
            int leftY = (partY==0)?Integer.MIN_VALUE:nums2[partY-1];  
            int rightY = (partY==n2)?Integer.MAX_VALUE:nums2[partY];  

            if(leftX <= rightY && leftY <= rightX) {  
               if((n1+n2) % 2 == 0) {  
                    int res = Math.max(leftX,leftY);  
                    res+=Math.min(rightX,rightY);  
                    return (double)res/2;  
                }  
                else  
                    return Math.max(leftX,leftY);  
            } else if(leftX > rightY) {  
                r=partX-1;  
            } else   
                l=partX+1;  
        }  
       return -1;  
    }  
}

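It is really about the same, except that it adds a special base case for when the shorter array has length 0. That feels like it is only there to please the OJ.

The key takeaway for this problem is still the idea behind it.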


Problem Description

There are two sorted arrays nums1 and nums2 of size m and n respectively.

Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).

Example 1:

nums1 = [1, 3]  
nums2 = [2]  

The median is 2.0

Example 2:

nums1 = [1, 2]  
nums2 = [3, 4]  

The median is (2 + 3)/2 = 2.5
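For checking the two examples above, a linear-time merge baseline can be handy. The sketch below is only a reference for testing, with a made-up class name `MergeBaseline`; it runs in O(m+n) and therefore does not meet the O(log (m+n)) requirement. It assumes the two arrays are not both empty.

```java
// Hypothetical O(m + n) reference: walk the merged order up to the median.
// Assumes nums1.length + nums2.length >= 1.
class MergeBaseline {
    static double median(int[] nums1, int[] nums2) {
        int total = nums1.length + nums2.length;
        int i = 0, j = 0;
        int prev = 0, cur = 0;
        // After the loop, cur is the element at 0-based merged index total / 2
        // and prev is the element just before it.
        for (int step = 0; step <= total / 2; step++) {
            prev = cur;
            if (j >= nums2.length || (i < nums1.length && nums1[i] <= nums2[j])) {
                cur = nums1[i++];
            } else {
                cur = nums2[j++];
            }
        }
        return total % 2 == 1 ? cur : ((double) prev + cur) / 2.0;
    }
}
```

Comparing a logarithmic solution against this baseline on random sorted inputs is a cheap way to catch off-by-one mistakes in the cut positions.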

Difficulty: Hard
Total Accepted: 225.7K
Total Submissions: 997.7K
Contributor: LeetCode
Companies: Google, Microsoft, Apple, Zenefits, Yahoo, Adobe, Dropbox
Related Topics: array, binary search, divide and conquer
