FourSum [source code]
import java.util.*;

public class FourSum {
    static
    /******************************************************************************/
    public class Solution {
        public List<List<Integer>> fourSum(int[] nums, int target) {
            List<List<Integer>> res = new ArrayList<>();
            if (nums.length < 4)
                return res;
            Arrays.sort(nums);
            for (int i = 0; i < nums.length - 3; i++) {
                // sorted: every quadruplet starting at i sums to at least 4 * nums[i];
                // long arithmetic keeps the bound correct for values near 10^9
                if (4L * nums[i] > target)
                    break;
                if (i > 0 && nums[i] == nums[i - 1])
                    continue;
                for (int j = i + 1; j < nums.length - 2; j++) { // last iteration uses [n - 3], [n - 2], [n - 1]
                    if (3L * nums[j] > (long) target - nums[i])
                        break;
                    if (j > i + 1 && nums[j] == nums[j - 1])
                        continue;
                    int lo = j + 1, hi = nums.length - 1;
                    long goal = (long) target - nums[i] - nums[j];
                    while (lo < hi) {
                        long sum = (long) nums[lo] + nums[hi];
                        if (sum == goal) {
                            res.add(Arrays.asList(nums[i], nums[j], nums[lo], nums[hi]));
                            // skip the duplicates of both values just used
                            while (lo < hi && nums[lo] == nums[lo + 1])
                                lo++;
                            while (lo < hi && nums[hi] == nums[hi - 1])
                                hi--;
                            lo++;
                            hi--;
                        } else if (sum < goal) {
                            lo++;
                        } else {
                            hi--;
                        }
                    }
                }
            }
            return res;
        }
    }
    /******************************************************************************/
    public static void main(String[] args) {
        FourSum.Solution tester = new FourSum.Solution();
        // inputs alternate: nums, then {target}
        int[][] inputs = {
            {1, -2, -5, -4, -3, 3, 3, 5}, {-11},
            {1, 0, -1, 0, -2, 2}, {0},
        };
        // expected quadruplets, listed in the order the solution above emits them
        int[][][] answers = {
            {{-5, -4, -3, 1}},
            {
                {-2, -1, 1, 2},
                {-2, 0, 0, 2},
                {-1, 0, 0, 1}
            },
        };
        // Printer and Matrix are the author's own helper classes (not shown in these notes)
        for (int i = 0; i < inputs.length / 2; i++) {
            int[] nums = inputs[2 * i];
            int target = inputs[1 + 2 * i][0];
            int[][] ans = answers[i];
            System.out.println(Printer.separator());
            int[][] output = Matrix.listToMatrix(tester.fourSum(nums, target));
            String ansStr = Matrix.printMatrix(ans);
            String outputStr = Matrix.printMatrix(output);
            System.out.println(
                Printer.wrapColor(Printer.array(nums) + " AND " + target, "magenta") +
                " -> \n" + outputStr +
                Printer.wrapColor(", expected: \n" + ansStr, ansStr.equals(outputStr) ? "green" : "red")
            );
        }
    }
}
Again, for an array problem: if you're not confident you can find an O(N) solution, you might as well sort first to make the problem easier.
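This algorithm is a direct adaptation of the 3Sum algorithm. It ran in 48 ms (beating 72% of submissions), which is acceptable. Nothing in the adaptation is especially special, with one point to watch. In 3Sum we break once nums[i] > 0, so what should the cutoff be here, now that target is no longer fixed at 0? This isn't obvious at first. Go back to the 3Sum setting and ask yourself why we can break once we pass 0, and what the underlying principle is. The answer is that the array is sorted, so any triplet whose smallest element is nums[i] sums to at least 3 * nums[i]; once that exceeds the target, no later i can work either. The same reasoning gives 4 * nums[i] > target for the first index here and 3 * nums[j] > target - nums[i] for the second.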
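Below is a minimal sketch of the generalized cutoff, assuming the array is already sorted. For a k-tuple whose smallest element is sorted[i], the sum is at least k * sorted[i]; with k = 3 and target = 0 this is the 3Sum rule nums[i] > 0. The helper name canStopAt is made up for illustration. It multiplies in long because the int product 4 * nums[i] wraps around for values near 10^9. After wrapping, the check can skip a break it should take, or break too early and drop valid answers.

```java
import java.util.Arrays;

class KSumBound {
    // Hypothetical helper: true when no k-tuple whose smallest element is sorted[i]
    // can reach target, since every such tuple sums to at least k * sorted[i].
    static boolean canStopAt(int[] sorted, int i, int k, long target) {
        return (long) k * sorted[i] > target; // long multiplication cannot overflow here
    }

    public static void main(String[] args) {
        int[] nums = {1, 0, -1, 0, -2, 2};
        Arrays.sort(nums); // [-2, -1, 0, 0, 1, 2]
        // With k = 3 and target = 0 this is exactly the 3Sum rule nums[i] > 0.
        for (int i = 0; i < nums.length; i++) {
            System.out.println(nums[i] + " -> stop? " + canStopAt(nums, i, 3, 0));
        }
        System.out.println(4 * 1_000_000_000);  // -294967296 (int arithmetic wrapped)
        System.out.println(4L * 1_000_000_000); // 4000000000
    }
}
```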
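Looking at the other top submissions, the underlying idea is basically the same. They are all variants of 3Sum, and just as with 3Sum, there are all sorts of idiosyncratic ways to write them.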
This is the top solution from the discussion forum. It's long, but there is something worth learning from it:
public List<List<Integer>> fourSum(int[] nums, int target) {
    ArrayList<List<Integer>> res = new ArrayList<List<Integer>>();
    if (nums == null || nums.length < 4) // check for null before reading nums.length
        return res;
    int len = nums.length;
    Arrays.sort(nums);
    int max = nums[len - 1];
    if (4 * nums[0] > target || 4 * max < target)
        return res;
    int i, z;
    for (i = 0; i < len; i++) {
        z = nums[i];
        if (i > 0 && z == nums[i - 1]) // avoid duplicate
            continue;
        if (z + 3 * max < target) // z is too small
            continue;
        if (4 * z > target) // z is too large
            break;
        if (4 * z == target) { // z is the boundary
            if (i + 3 < len && nums[i + 3] == z)
                res.add(Arrays.asList(z, z, z, z));
            break;
        }
        threeSumForFourSum(nums, target - z, i + 1, len - 1, res, z);
    }
    return res;
}

// Find all possible distinct three numbers adding up to the target
// in sorted array nums[] between indices low and high. If there are any,
// add all of them into the ArrayList fourSumList, using
// fourSumList.add(Arrays.asList(z1, the three numbers))
public void threeSumForFourSum(int[] nums, int target, int low, int high, ArrayList<List<Integer>> fourSumList,
        int z1) {
    if (low + 1 >= high)
        return;
    int max = nums[high];
    if (3 * nums[low] > target || 3 * max < target)
        return;
    int i, z;
    for (i = low; i < high - 1; i++) {
        z = nums[i];
        if (i > low && z == nums[i - 1]) // avoid duplicate
            continue;
        if (z + 2 * max < target) // z is too small
            continue;
        if (3 * z > target) // z is too large
            break;
        if (3 * z == target) { // z is the boundary
            if (i + 1 < high && nums[i + 2] == z)
                fourSumList.add(Arrays.asList(z1, z, z, z));
            break;
        }
        twoSumForFourSum(nums, target - z, i + 1, high, fourSumList, z1, z);
    }
}

// Find all possible distinct two numbers adding up to the target
// in sorted array nums[] between indices low and high. If there are any,
// add all of them into the ArrayList fourSumList, using
// fourSumList.add(Arrays.asList(z1, z2, the two numbers))
public void twoSumForFourSum(int[] nums, int target, int low, int high, ArrayList<List<Integer>> fourSumList,
        int z1, int z2) {
    if (low >= high)
        return;
    if (2 * nums[low] > target || 2 * nums[high] < target)
        return;
    int i = low, j = high, sum, x;
    while (i < j) {
        sum = nums[i] + nums[j];
        if (sum == target) {
            fourSumList.add(Arrays.asList(z1, z2, nums[i], nums[j]));
            x = nums[i];
            while (++i < j && x == nums[i]) // avoid duplicate
                ;
            x = nums[j];
            while (i < --j && x == nums[j]) // avoid duplicate
                ;
        }
        if (sum < target)
            i++;
        if (sum > target)
            j--;
    }
    return;
}
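The algorithm itself isn't hard. Apart from the 2Sum base case, the two layers above it only delegate. What's worth learning is that each layer has more early exits than mine, two more: the "z is too small" skip and the "z is the boundary" shortcut.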
Further down the thread, people start generalizing directly to k-Sum:
Using List<Integer> path to store the selected members, instead of passing them separately.
So it has a bit of a backtracking flavor.
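On second thought, the list isn't really there for backtracking. The backtracking itself can be driven entirely by an int such as startIndex. The list plays a role more like an accumulator tag in a recursion problem: it is passed downward until the recursion bottoms out.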
List<List<Integer>> kSum_Trim(int[] a, int target, int k) {
    List<List<Integer>> result = new ArrayList<>();
    if (a == null || a.length < k || k < 2) return result;
    Arrays.sort(a);
    kSum_Trim(a, target, k, 0, result, new ArrayList<>());
    return result;
}

void kSum_Trim(int[] a, int target, int k, int start, List<List<Integer>> result, List<Integer> path) {
    int max = a[a.length - 1];
    if (a[start] * k > target || max * k < target) return;
    if (k == 2) { // 2 Sum
        int left = start;
        int right = a.length - 1;
        while (left < right) {
            if (a[left] + a[right] < target) left++;
            else if (a[left] + a[right] > target) right--;
            else {
                result.add(new ArrayList<>(path));
                result.get(result.size() - 1).addAll(Arrays.asList(a[left], a[right]));
                left++; right--;
                while (left < right && a[left] == a[left - 1]) left++;
                while (left < right && a[right] == a[right + 1]) right--;
            }
        }
    }
    else { // k Sum
        for (int i = start; i < a.length - k + 1; i++) {
            if (i > start && a[i] == a[i - 1]) continue;
            if (a[i] + max * (k - 1) < target) continue;
            if (a[i] * k > target) break;
            if (a[i] * k == target) {
                if (a[i + k - 1] == a[i]) {
                    result.add(new ArrayList<>(path));
                    List<Integer> temp = new ArrayList<>();
                    for (int x = 0; x < k; x++) temp.add(a[i]);
                    result.get(result.size() - 1).addAll(temp); // Add result immediately.
                }
                break;
            }
            path.add(a[i]);
            kSum_Trim(a, target - a[i], k - 1, i + 1, result, path);
            path.remove(path.size() - 1); // Backtracking
        }
    }
}
If you do view this problem as backtracking, then all the early exits he wrote earlier can be understood as pruning.
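Also, the main difficulty in writing a so-called backtracking algorithm is really just the recursion. Look at how he writes it: 2Sum is the base case, and every k greater than 2 is simply a recursive call. It's a clean, by-the-book structure.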
Note how this part is written:
left++; right--;
while (left < right && a[left] == a[left - 1]) left++;
while (left < right && a[right] == a[right + 1]) right--;
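This duplicate-skipping code differs slightly from the way I wrote it before, though it's essentially the same. It resembles the 3Sum solutions where many people put a pre-increment directly in the loop header. In this style the pointer is incremented once before the header's first check, so for left, say, checking starts at position left + 1. The pointer you step through the run with is the leading one, and you compare it backward against the element it just left. I still think the two styles differ very little; what matters is that you understand which one you're using.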
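To check that the two skip styles really are interchangeable, here is a small sketch. It collects the distinct pairs in a sorted array that sum to a target, written twice. The first version uses the look-ahead style (skip, then step) from my solution. The second uses the look-behind style (step, then compare with the previous element) from kSum_Trim. The class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

class SkipStyles {
    // Look-ahead style: skip over equal neighbours first, then step past the run.
    static List<List<Integer>> lookAhead(int[] a, int target) {
        List<List<Integer>> res = new ArrayList<>();
        int lo = 0, hi = a.length - 1;
        while (lo < hi) {
            int sum = a[lo] + a[hi];
            if (sum == target) {
                res.add(List.of(a[lo], a[hi]));
                while (lo < hi && a[lo] == a[lo + 1]) lo++;
                while (lo < hi && a[hi] == a[hi - 1]) hi--;
                lo++;
                hi--;
            } else if (sum < target) {
                lo++;
            } else {
                hi--;
            }
        }
        return res;
    }

    // Look-behind style: step first, then compare against the element just left behind.
    static List<List<Integer>> lookBehind(int[] a, int target) {
        List<List<Integer>> res = new ArrayList<>();
        int left = 0, right = a.length - 1;
        while (left < right) {
            int sum = a[left] + a[right];
            if (sum < target) left++;
            else if (sum > target) right--;
            else {
                res.add(List.of(a[left], a[right]));
                left++;
                right--;
                while (left < right && a[left] == a[left - 1]) left++;
                while (left < right && a[right] == a[right + 1]) right--;
            }
        }
        return res;
    }

    public static void main(String[] args) {
        int[] sorted = {-2, -1, -1, 0, 0, 1, 1, 2, 2};
        System.out.println(lookAhead(sorted, 1));  // [[-1, 2], [0, 1]]
        System.out.println(lookBehind(sorted, 1)); // [[-1, 2], [0, 1]]
    }
}
```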
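Instinctively, it may also seem that this recursion lacks a proper, strong base case. It recurses on an ever-shorter suffix of the array, so the intuitive thought is that the base case should also check start. Note, though, that start and k are tied together here: the loop bound i < a.length - k + 1 guarantees at least k - 1 elements remain after i for the recursive call. So although there's no check on start, there is a check on k. It tests for 2 rather than the usual 0, but that is enough.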
Also note the path.remove(path.size() - 1); at the end of the loop body. It shouldn't be hard to understand: in backtracking, you must remember to undo your choice when you back out.
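Here is yet another k-Sum algorithm. The idea is similar. The differences are that it has far less pruning and that the helper returns its result instead of working through side effects.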
import java.util.*;

public class Solution {
    int len = 0;

    public List<List<Integer>> fourSum(int[] nums, int target) {
        len = nums.length;
        Arrays.sort(nums);
        return kSum(nums, target, 4, 0);
    }

    private ArrayList<List<Integer>> kSum(int[] nums, int target, int k, int index) {
        ArrayList<List<Integer>> res = new ArrayList<List<Integer>>();
        if (index >= len) {
            return res;
        }
        if (k == 2) {
            int i = index, j = len - 1;
            while (i < j) {
                //find a pair
                if (target - nums[i] == nums[j]) {
                    List<Integer> temp = new ArrayList<>();
                    temp.add(nums[i]);
                    temp.add(target - nums[i]);
                    res.add(temp);
                    //skip duplication
                    while (i < j && nums[i] == nums[i + 1]) i++;
                    while (i < j && nums[j - 1] == nums[j]) j--;
                    i++;
                    j--;
                //move left bound
                } else if (target - nums[i] > nums[j]) {
                    i++;
                //move right bound
                } else {
                    j--;
                }
            }
        } else {
            for (int i = index; i < len - k + 1; i++) {
                //use current number to reduce ksum into k-1sum
                ArrayList<List<Integer>> temp = kSum(nums, target - nums[i], k - 1, i + 1);
                if (temp != null) {
                    //add previous results
                    for (List<Integer> t : temp) {
                        t.add(0, nums[i]);
                    }
                    res.addAll(temp);
                }
                while (i < len - 1 && nums[i] == nums[i + 1]) {
                    //skip duplicated numbers
                    i++;
                }
            }
        }
        return res;
    }
}
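Note how he combines the results at the end. It resembles the hd :: f(tl) pattern I learned with OCaml: prepend the current number to each list returned by the recursive call. The technique is quite intuitive; I really got a lot out of course 426.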
Some solutions claim to beat N^3, but Stefan argues that the lower bound is N^3:
Some people say their solutions are O(n^2 log n) or even O(n^2), but...
Consider cases where nums is the n numbers from 1 to n.
=> There are Θ(n^4) different quadruplets (nC4, to be exact, so about n^4 / 24).
=> There are Θ(n) possible sums (from 1+2+3+4 to (n-3)+(n-2)+(n-1)+n, so about 4n sums).
=> At least one sum must have Ω(n^3) different quadruplets.
=> For that sum, we must generate those Ω(n^3) quadruplets.
=> For these cases we have to do Ω(n^3) work.
=> O(n^2 log n) or even O(n^2) are impossible.
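This is a really good way of thinking. Note that we have to consider what a single input case/instance actually represents. Strictly speaking it is one nums together with one target. To keep the discussion simple, he fixes a single nums and lets different targets stand for different input cases. This way of thinking is worth mastering.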
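Stefan's counting argument can be checked numerically. The sketch below takes nums = 1..n and groups all C(n, 4) quadruplets by their sum. It then reports the size of the largest group. The class and method names are hypothetical. If the argument holds, the largest group grows like n^3, so the printed ratio max / n^3 should stay bounded away from zero as n doubles. The count is brute force over four nested loops, so keep n small.

```java
class QuadrupletSumCount {
    // Largest number of quadruplets a < b < c < d from 1..n that share a single sum.
    static long maxPerSum(int n) {
        long[] countBySum = new long[4 * n + 1]; // for n >= 4, sums range over 10..4n-6
        for (int a = 1; a <= n; a++)
            for (int b = a + 1; b <= n; b++)
                for (int c = b + 1; c <= n; c++)
                    for (int d = c + 1; d <= n; d++)
                        countBySum[a + b + c + d]++;
        long max = 0;
        for (long cnt : countBySum) max = Math.max(max, cnt);
        return max;
    }

    public static void main(String[] args) {
        for (int n = 10; n <= 80; n *= 2) {
            long m = maxPerSum(n);
            System.out.printf("n=%d max=%d max/n^3=%.4f%n", n, m, m / Math.pow(n, 3));
        }
    }
}
```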
Here is another algorithm, one that achieves N^2 on average:
public List<List<Integer>> fourSum(int[] num, int target) {
    Arrays.sort(num);
    Map<Integer, List<int[]>> twoSumMap = new HashMap<>(); // for holding visited pair sums. All pairs with the same pair sum are grouped together
    Set<List<Integer>> res = new HashSet<>(); // for holding the results
    for (int i = 0; i < num.length; i++) {
        // get rid of repeated pair sums
        if (i > 1 && num[i] == num[i - 2]) continue;
        for (int j = i + 1; j < num.length; j++) {
            // get rid of repeated pair sums
            if (j > i + 2 && num[j] == num[j - 2]) continue;
            // for each pair sum, check if the pair sum that is needed to get the target has been visited.
            if (twoSumMap.containsKey(target - (num[i] + num[j]))) {
                // if so, get all the pairs that contribute to this visited pair sum.
                List<int[]> ls = twoSumMap.get(target - (num[i] + num[j]));
                for (int[] pair : ls) {
                    // we have two pairs: one is indicated as (pair[0], pair[1]), the other is (i, j).
                    // we first need to check if they are overlapping with each other.
                    int m1 = Math.min(pair[0], i); // m1 will always be the smallest index
                    int m2 = Math.min(pair[1], j); // m2 will be one of the middle two indices
                    int m3 = Math.max(pair[0], i); // m3 will be one of the middle two indices
                    int m4 = Math.max(pair[1], j); // m4 will always be the largest index
                    if (m1 == m3 || m1 == m4 || m2 == m3 || m2 == m4) continue; // two pairs are overlapping, so just ignore this case
                    res.add(Arrays.asList(num[m1], num[Math.min(m2, m3)], num[Math.max(m2, m3)], num[m4])); // else record the result
                }
            }
            // mark that we have visited current pair and add it to the corresponding pair sum group.
            // here we've encoded the pair indices i and j into an integer array of length 2.
            twoSumMap.computeIfAbsent(num[i] + num[j], key -> new ArrayList<>()).add(new int[] {i, j});
        }
    }
    return new ArrayList<List<Integer>>(res);
}
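The m1..m4 part is actually simple. Since the pairs store indices, all that's needed is a direct comparison: the two pairs must not share an index. I'm not sure why he bothers computing m1..m4. If it's only to make the add below more convenient, you could put the four numbers into a small array and sort it, since sorting four elements is very fast. In fact, is this sort even needed at all?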
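For comparison, here is a sketch of the simpler alternative described above. It tests the two index pairs for a shared index directly, then sorts the four values into one canonical order before returning them. The helper name combine is hypothetical. It returns null when the pairs overlap, for example index pairs (1, 2) and (1, 3), which would otherwise use the element at index 1 twice.

```java
import java.util.Arrays;
import java.util.List;

class PairCombine {
    // Hypothetical helper: merge index pairs (p0, p1) and (i, j) into a sorted quadruplet,
    // or return null if the pairs share an index (the same element would be used twice).
    static List<Integer> combine(int[] num, int p0, int p1, int i, int j) {
        if (p0 == i || p0 == j || p1 == i || p1 == j) return null;
        int[] quad = {num[p0], num[p1], num[i], num[j]};
        Arrays.sort(quad); // canonical order, so equal quadruplets compare equal in a Set
        return List.of(quad[0], quad[1], quad[2], quad[3]);
    }

    public static void main(String[] args) {
        int[] num = {-2, -1, 0, 0, 1, 2};
        System.out.println(combine(num, 0, 5, 2, 3)); // [-2, 0, 0, 2]
        System.out.println(combine(num, 1, 2, 1, 3)); // null: index 1 would be used twice
    }
}
```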
Note the use of computeIfAbsent here:
default V computeIfAbsent(K key,Function<? super K,? extends V> mappingFunction)
If the specified key is not already associated with a value (or is mapped to null), attempts to compute its value using the given mapping function and enters it into this map unless null.
If the function returns null no mapping is recorded. If the function itself throws an (unchecked) exception, the exception is rethrown, and no mapping is recorded. The most common usage is to construct a new object serving as an initial mapped value or memoized result.
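The sketch below shows the grouping pattern that twoSumMap relies on; the class and method names are hypothetical. computeIfAbsent returns the list already stored for a sum. The first time a sum is seen, it creates, stores, and returns a new empty list instead. Either way the add can be chained onto the call. Without it, you would write a get, a null check, and a put by hand.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class ComputeIfAbsentDemo {
    // Group index pairs (i, j) by their pair sum nums[i] + nums[j], as twoSumMap does.
    static Map<Integer, List<int[]>> groupPairSums(int[] nums) {
        Map<Integer, List<int[]>> bySum = new HashMap<>();
        for (int i = 0; i < nums.length; i++)
            for (int j = i + 1; j < nums.length; j++)
                // the lambda runs only the first time this sum is seen
                bySum.computeIfAbsent(nums[i] + nums[j], key -> new ArrayList<>())
                     .add(new int[] {i, j});
        return bySum;
    }

    public static void main(String[] args) {
        Map<Integer, List<int[]>> bySum = groupPairSums(new int[] {1, 2, 3, 4});
        for (Map.Entry<Integer, List<int[]>> e : bySum.entrySet()) {
            StringBuilder sb = new StringBuilder();
            for (int[] p : e.getValue()) sb.append("(").append(p[0]).append(",").append(p[1]).append(")");
            System.out.println(e.getKey() + " -> " + sb);
        }
    }
}
```

For {1, 2, 3, 4} the sum 5 is the only one reached by two pairs, (0,3) and (1,2); every other sum gets a single-element list.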
Here is the author's explanation of the algorithm's complexity:
Basic idea is to reduce the 4Sum problem to 2Sum one. In order to achieve that, we can use an array (size of n^2) to store the pair sums and this array will act as the array in 2Sum case (Here n is the size of the original 1D array and it turned out that we do not even need to explicitly use the n^2 sized array ). We also use a hashmap to mark if a pair sum has been visited or not (the same as in the 2Sum case). The tricky part here is that we may have multiple pairs that result in the same pair sum. So we will use a list to group these pairs together. For every pair with a particular sum, check if the pair sum that is needed to get the target has been visited. If so, further check if there is overlapping between these two pairs. If not, record the result.
Time complexity to get all the pairs is O(n^2). For each pair, if the pair sum needed to get the target has been visited, the time complexity will be O(k), where k is the maximum size of the lists holding pairs with visited pair sum. Therefore the total time complexity will be O(k*n^2). Now we need to determine the range of k. Basically the more distinct pair sums we get, the smaller k will be. If all the pair sums are different from each other, k will just be 1. However, if we have many repeated elements in the original 1D array, or in some extreme cases such as the elements form an arithmetic progression, k can be of the order of n (strictly speaking, for the repeated elements case, k can go as high as n^2, but we can get rid of many of them). On average k will be some constant between 1 and n for normal elements distribution in the original 1D array. So on average our algorithm will go in O(n^2) but with worst case of O(n^3).
Overall, his explanation is fairly reasonable. This algorithm is in fact the approach I didn't find when I did 3Sum: simply using a set to remove duplicates.
java> Set<List<Integer>> set = new HashSet<>()
java.util.Set<java.util.List<java.lang.Integer>> set = []
java> set.add(Arrays.asList(1,2,3))
java.lang.Boolean res1 = true
java> set
java.util.Set<java.util.List<java.lang.Integer>> set = [[1, 2, 3]]
java> set.add(Arrays.asList(1,2,3))
java.lang.Boolean res2 = false
java> set
java.util.Set<java.util.List<java.lang.Integer>> set = [[1, 2, 3]]
java> List<Integer> ls1 = Arrays.asList(1,2,3)
java.util.List<java.lang.Integer> ls1 = [1, 2, 3]
java> List<Integer> ls2 = Arrays.asList(1,2,3)
java.util.List<java.lang.Integer> ls2 = [1, 2, 3]
java> boolean b = ls1 == ls2
boolean b = false
java> ls1.equals(ls2)
java.lang.Boolean res6 = true
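That's why he puts the four values in sorted order before adding them to res. The set's deduplication only works if equal quadruplets appear in the same order. Otherwise, two lists with the same elements in a different order would be treated as different lists and both added, which is not what we want.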
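A minimal sketch of why the canonical order matters; the class and method names are made up. List.equals and List.hashCode are order-sensitive. As a result, the same four values in two different orders count as two distinct keys in a HashSet unless each quadruplet is sorted first.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class SortedKeyDemo {
    // Count distinct quadruplets in a HashSet, optionally sorting each one first.
    static int distinctCount(List<int[]> quads, boolean sortFirst) {
        Set<List<Integer>> seen = new HashSet<>();
        for (int[] q : quads) {
            int[] copy = q.clone();
            if (sortFirst) Arrays.sort(copy);
            List<Integer> key = new ArrayList<>();
            for (int x : copy) key.add(x);
            seen.add(key); // HashSet relies on List.hashCode/equals, both order-sensitive
        }
        return seen.size();
    }

    public static void main(String[] args) {
        List<int[]> quads = List.of(new int[] {-1, 0, 0, 1}, new int[] {0, 1, -1, 0});
        System.out.println(distinctCount(quads, false)); // 2
        System.out.println(distinctCount(quads, true));  // 1
    }
}
```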
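Overall, his algorithm is an intuitive rewrite of 2Sum. In 2Sum, each time you reach an integer you look for its complementary integer. Here in 4Sum, each time you reach a pair (a 2Sum), you look for a complementary pair. Then a set removes the duplicates, and that's it.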
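This algorithm turns out to be very slow. One reason is that deduplicating with the set isn't efficient. Every candidate list has to be hashed, and comparing two lists calls equals, which compares them element by element. Together this adds up to a lot of time.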
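There's also a small simplification available here: m1 is always pair[0], and m3 is always i. Since i is the outermost loop variable, pairs are recorded in order of their first index. So the first index of any earlier pair is at most the current i. Note that this property does not hold for j.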
After changing m1 and m3 accordingly:
int m1 = pair[0]; // m1 will always be the smallest index
int m2 = Math.min(pair[1], j); // m2 will be one of the middle two indices
int m3 = i; // m3 will be one of the middle two indices
int m4 = Math.max(pair[1], j); // m4 will always be the largest index
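the meaning of the comparison m1 == m3 || m1 == m4 || m2 == m3 || m2 == m4 becomes clearer. It is just an index comparison between the two pairs, making sure they share no index. At first I thought this check was only a small speedup: without it the algorithm would still work, just spending more time in the set's equals checks. That's wrong, though. It isn't merely a speedup; it blocks false positives. Some quadruplets would be built from index pairs (1,2) and (1,3). Such a quadruplet must not go into res, because it would look as if the value at index 1 appeared twice in nums, which it doesn't.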
Once these small differences are understood, the algorithm is very similar to 2Sum.
The only remaining thing that's hard to understand is why he has the checks i > 1 && num[i] == num[i - 2] and j > i + 2 && num[j] == num[j - 2].
Here is the author's own explanation:
One possible source for repeated pair sum is repeated element, say you have elements 0, 0, 0, 1, 2, ..., once we have checked the pair sums for the first 0, we do not want to do that for the second and third one because it will yield the same pair sums. So we simply skip those cases. The reason we have "num[i] == num[i - 2]" instead of "num[i] == num[i - 1]" is to account for special cases such as all the elements are the same.
But I don't think this really explains it. When I commented out these two lines, the OJ runtime barely changed, so I don't know what they're really for.
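As for the num[m1], num[Math.min(m2, m3)], num[Math.max(m2, m3)], num[m4] part: at first I thought sorting would be faster. It isn't. Java's sort isn't as well optimized for tiny inputs as I imagined, and after I switched to sorting, my submission timed out. // Later I found that after one more change it no longer timed out. The timeout only happened because I had also deleted the author's skip lines; once I restored them it passed, at about the same speed. Still, doesn't that suggest the author's skips actually do something?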
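When I changed the skip to compare with the immediately preceding element instead of the one two positions back, this case broke: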
[0,0,0,0]
0
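Everything in this input gets skipped. So, given his vague explanation earlier, he most likely arrived at this skip rule by adjusting to OJ feedback. This also shows that the approach is weaker overall. If an interview requires an algorithm that has to be tweaked bit by bit in response to OJ feedback, the interview will very likely go badly. The algorithm also relies on sorting, so it has no advantage over the two-pointer algorithm. How do you avoid reaching for it in a real interview? Don't get stuck in a mental rut, and keep a bit of a search-tree mindset when looking for an approach. Don't jump to the HashMap approach the moment you hear 2Sum; once you've sorted, two pointers is clearly the more sensible choice.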
Still, it's worth studying, especially the complexity discussion.
Also, when I changed the skip to
if (j > i + 2 && num[j] == num[j - 1]) continue;
it still failed, which I couldn't quite understand. The failing case was:
Input:
[0,1,5,0,1,5,5,-4]
11
Output:
[[0,1,5,5]]
Expected:
[[-4,5,5,5],[0,1,5,5]]
I traced it on paper, and it seemed correct to me?
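What you need to understand about his skip-by-two is that it only applies to runs of at least three equal entries, skipping from the third one onward. Runs of length 2 are left alone. This relaxes the deduplication a little and lets some length-2 runs through, but judging from this case it seems to be necessary?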
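I figured out where the skip-by-one version goes wrong. In the example above, the sorted array is -4,0,0,1,1,5,5,5. Ordinary skipping keeps only the leftmost occurrence of each value, so for the pair sum 1 we record (0,5), and later we reach the pair (5,6). That's where it fails. At (5,6) (5 + 5 = 10) we should be able to combine with -4 + 5 = 1. But only the leftmost pair (0,5) was collected, so the two pairs overlap at index 5 and the combination is skipped. His skip-by-two keeps runs of length 2. For the pair sum -4 + 5 = 1 it therefore keeps both (0,5) and (0,6), and (0,6) can later combine successfully with (5,7).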
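The failures above can be replayed with a small sketch. It re-implements the pair-sum algorithm with the skip distance as a hypothetical gap parameter. That parameter applies to both the i check and the j check: gap = 2 is the author's rule, and gap = 1 compares each element with its immediate predecessor. With that assumption, gap = 1 reproduces both reported failures, while gap = 2 returns the expected answers. The overlap test compares indices directly instead of going through m1..m4, and each quadruplet is sorted before it goes into the set.

```java
import java.util.*;

class SkipGapDemo {
    // Pair-sum 4Sum with the duplicate-skip distance as a parameter (hypothetical "gap").
    static Set<List<Integer>> fourSum(int[] input, int target, int gap) {
        int[] num = input.clone();
        Arrays.sort(num);
        Map<Integer, List<int[]>> twoSumMap = new HashMap<>();
        Set<List<Integer>> res = new HashSet<>();
        for (int i = 0; i < num.length; i++) {
            if (i >= gap && num[i] == num[i - gap]) continue;   // gap 2: i > 1 && num[i] == num[i - 2]
            for (int j = i + 1; j < num.length; j++) {
                if (j > i + gap && num[j] == num[j - gap]) continue;
                List<int[]> ls = twoSumMap.get(target - (num[i] + num[j]));
                if (ls != null) {
                    for (int[] pair : ls) {
                        // reject pairs that share an index with (i, j)
                        if (pair[0] == i || pair[0] == j || pair[1] == i || pair[1] == j) continue;
                        int[] q = {num[pair[0]], num[pair[1]], num[i], num[j]};
                        Arrays.sort(q);
                        res.add(List.of(q[0], q[1], q[2], q[3]));
                    }
                }
                twoSumMap.computeIfAbsent(num[i] + num[j], k -> new ArrayList<>()).add(new int[] {i, j});
            }
        }
        return res;
    }

    public static void main(String[] args) {
        int[] a = {0, 1, 5, 0, 1, 5, 5, -4};
        int[] b = {0, 0, 0, 0};
        System.out.println(fourSum(a, 11, 2)); // both [-4, 5, 5, 5] and [0, 1, 5, 5]
        System.out.println(fourSum(a, 11, 1)); // [[0, 1, 5, 5]] only
        System.out.println(fourSum(b, 0, 2));  // [[0, 0, 0, 0]]
        System.out.println(fourSum(b, 0, 1));  // []
    }
}
```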
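I don't yet know whether there's a deeper principle behind this, but it's a nice idea. The algorithm also shows indirectly that if you pick the wrong direction on an algorithm problem, it really does get harder the more you write. I'm honestly not confident I could come up with this small skip-by-two trick on the spot in a real interview.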
Problem Description
Given an array S of n integers, are there elements a, b, c, and d in S such that a + b + c + d = target? Find all unique quadruplets in the array which gives the sum of target.
Note: The solution set must not contain duplicate quadruplets.
For example, given array S = [1, 0, -1, 0, -2, 2], and target = 0.
A solution set is:
[
[-1, 0, 0, 1],
[-2, -1, 1, 2],
[-2, 0, 0, 2]
]
Difficulty: Medium
Total Accepted: 124.1K
Total Submissions: 463.9K
Contributor: LeetCode
Related Topics
Array, Hash Table, Two Pointers
Similar Questions
Two Sum, 3Sum, 4Sum II