Find the k-th sum from sums of every pair


I'm given an array with n elements, and I need to find the k-th sum among the n^2 sums of every ordered pair, with the sums in ascending order, in O(n log n) time.


Example input

The first line gives the number of elements and the index k of the sum to find. The second line lists the numbers whose pairwise sums we need to generate.

 3 6
 1 4 6

For this list the answer is 8. Below are all pairwise sums in ascending order; 8, the sum 4+4, is in the 6th position.

2 5 5 7 7 8 10 10 12

where the first three elements are generated as follows:

  • 1+1 = 2
  • 1+4 = 5
  • 4+1 = 5
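A brute-force reference helps pin down the definition before optimizing. The sketch below (the function name `kth_pair_sum_brute_force` is hypothetical) reads input in the format above, builds all n^2 ordered-pair sums, sorts them, and returns the k-th one (1-based). It runs in O(n^2 log n), so it is only meant for checking faster solutions on small inputs.

```python
def kth_pair_sum_brute_force(text: str) -> int:
    """Return the k-th smallest (1-based) of all n*n ordered-pair sums.

    Assumes `text` follows the input format above: "n k" on the first
    line and the n numbers on the second line.
    """
    lines = text.strip().splitlines()
    n, k = map(int, lines[0].split())
    nums = list(map(int, lines[1].split()))
    assert len(nums) == n, "first line must match the number of values"

    # Every ordered pair (i, j), including i == j, contributes one sum.
    sums = sorted(a + b for a in nums for b in nums)
    return sums[k - 1]


print(kth_pair_sum_brute_force("3 6\n1 4 6"))  # prints 8
```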

Edit: I've realized that the main problem is finding the position of each element's sum with itself. Here is an example to make it clearer.

For the sequence [1, 4, 10] we have:

2 5 5 8 11 11 14 14 20

The problem is where to place the sum 4+4: that depends on whether 1+10 > 4+4. The other sums have fixed positions because, with the elements in ascending order, second + last is always larger than last + first.
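To make the observation concrete: the n^2 ordered sums are the n "diagonal" sums a[i] + a[i], each counted once, plus every a[i] + a[j] with i < j, counted twice (once per order). The sketch below (both helper names are hypothetical) builds the sorted list that way and reports the 1-based position of a diagonal sum, showing how 4+4 moves between the two examples.

```python
def pair_sums_from_halves(nums):
    """Build all n*n ordered-pair sums from the i <= j half of the sum matrix."""
    a = sorted(nums)
    n = len(a)
    diagonal = [a[i] + a[i] for i in range(n)]  # each appears once
    off_diagonal = [a[i] + a[j] for i in range(n) for j in range(i + 1, n)]
    return sorted(diagonal + off_diagonal * 2)  # i < j pairs appear twice


def position_of_diagonal_sum(nums, value):
    """1-based position of the self-sum value + value in the sorted list.

    If the same total also arises from other pairs, this returns the
    first position holding that total.
    """
    return pair_sums_from_halves(nums).index(value + value) + 1


print(position_of_diagonal_sum([1, 4, 6], 4))   # prints 6  (1+6 = 7 < 8)
print(position_of_diagonal_sum([1, 4, 10], 4))  # prints 4  (1+10 = 11 > 8)
```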

fafl (best answer)

Thanks to juvian solving the hard parts, I was able to write this solution; the comments should explain it:

def count_sums_of_at_most(amount, nums1, nums2):

    p1 = 0  # Pointer into the first array, start at the beginning
    p2 = len(nums2) - 1  # Pointer into the second array, start at the end

    # Move p1 up and p2 down, walking through the "diagonal" in O(n)
    sum_count = 0
    while p1 < len(nums1) and p2 >= 0:
        if nums1[p1] + nums2[p2] > amount:
            # Sum too large: try a smaller partner from nums2
            p2 -= 1
        else:
            # Found a valid p2 for the given p1: nums2[0..p2] all fit
            sum_count += p2 + 1
            p1 += 1

    # If p2 dropped below 0, no larger nums1[p1] can fit either, so we are done
    return sum_count

def find_sum(k, nums1, nums2):

    # Sort both arrays, this runs in O(n * log(n))
    nums1.sort()
    nums2.sort()

    # Binary search through all sums, runs in O(n * log(max_sum))
    low = nums1[0] + nums2[0]
    high = nums1[-1] + nums2[-1]
    while low <= high:
        mid = (high + low) // 2
        sum_count = count_sums_of_at_most(mid, nums1, nums2)
        if sum_count >= k:
            high = mid - 1
        else:
            low = mid + 1

    return low

arr = [1, 4, 5, 6]
for k in range(1, 1 + len(arr) ** 2):
    print('sum', k, 'is', find_sum(k, arr, arr))

This prints:

sum 1 is 2
sum 2 is 5
sum 3 is 5
sum 4 is 6
sum 5 is 6
sum 6 is 7
sum 7 is 7
sum 8 is 8
sum 9 is 9
sum 10 is 9
sum 11 is 10
sum 12 is 10
sum 13 is 10
sum 14 is 11
sum 15 is 11
sum 16 is 12
Pat Zhang

Edit: this is O(n^2 log n), since all n^2 sums go through the priority queue.

As I understand the problem, the first number on the first line is the number of elements, and the second number is k.

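You can solve this with a PriorityQueue, which keeps the numbers ordered for you as you insert them. Use two nested for loops so that each unordered pair of indices is visited once (pq is the PriorityQueue<Integer> and a is the input array):

    for (int i = 0; i < n; i++)
        for (int j = 0; j <= i; j++)
            if (j == i) pq.add(a[i] + a[j]);
            else { pq.add(a[i] + a[j]); pq.add(a[i] + a[j]); }

If j == i, a[i] + a[j] goes into the PriorityQueue once; otherwise it goes in twice, once for each order. Then poll the PriorityQueue k times: the k-th value polled is the answer (the 6th for the example). Don't just iterate over the queue, because a PriorityQueue's iterator does not return elements in sorted order.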

Will edit with full code if you'd like.
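For comparison, here is a Python sketch of the same priority-queue idea using the standard `heapq` module (the function name is hypothetical). It pushes each i <= j sum once or twice, as described, then pops k times. Because it stores all n^2 sums, it costs O(n^2 log n) time and O(n^2) memory, well above the O(n log n) target.

```python
import heapq


def kth_pair_sum_heap(nums, k):
    """Return the k-th smallest (1-based) ordered-pair sum via a min-heap."""
    heap = []
    n = len(nums)
    for i in range(n):
        for j in range(i + 1):  # visit each unordered pair once (j <= i)
            s = nums[i] + nums[j]
            heapq.heappush(heap, s)
            if j != i:  # (i, j) and (j, i) are distinct ordered pairs
                heapq.heappush(heap, s)
    # A heap is only partially ordered, so pop k times instead of iterating.
    for _ in range(k - 1):
        heapq.heappop(heap)
    return heapq.heappop(heap)


print(kth_pair_sum_heap([1, 4, 6], 6))  # prints 8
```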

juvian

This can be solved in O(n log maxSum).

Pseudocode:

sort(array)
low = array[1] * 2
high = array[n] * 2
while (low <= high): (binarySearch between low and high)
    mid = floor((high + low) / 2)
    qty = checkHowManySumsAreEqualOrLessThan(mid)
    if qty >= k:
        high = mid - 1
    else:
        low = mid + 1
answer = low  // low is the smallest mid with qty >= k; at low - 1, qty < k, so the k-th sum must be exactly low

Sorting is O(n log n). The binary search runs O(log(array[n] * 2 - array[1] * 2)) iterations.
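Putting the costs together, with a_1 ≤ … ≤ a_n the sorted values (1-indexed, as in the pseudocode) and an O(n) count per binary-search step. The loop halves a range of 2a_n − 2a_1 + 1 integers, so it runs at most ⌊log2(2a_n − 2a_1 + 1)⌋ + 1 times:

```latex
$$
T(n) = \underbrace{O(n \log n)}_{\text{sorting}}
     + \Bigl(\bigl\lfloor \log_2(2a_n - 2a_1 + 1) \bigr\rfloor + 1\Bigr)
       \cdot \underbrace{O(n)}_{\text{count per step}}
$$
```

This is O(n log n) when the spread a_n − a_1 is bounded by a polynomial in n; for arbitrary integer values the second term can dominate.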

checkHowManySumsAreEqualOrLessThan(mid) can be done in O(n) using two pointers; let me know if you can't figure out how.
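One possible two-pointer count, sketched in Python under the assumption that the array is already sorted (as it is after the first pseudocode step); `count_sums_at_most` is a hypothetical snake_case stand-in for checkHowManySumsAreEqualOrLessThan. One pointer walks up from the smallest value while the other walks down from the largest, so each moves at most n times.

```python
def count_sums_at_most(a, mid):
    """Count ordered pairs (i, j) with a[i] + a[j] <= mid in O(n).

    Assumes `a` is sorted in ascending order.
    """
    count = 0
    j = len(a) - 1
    for i in range(len(a)):
        # Shrink j until a[i] + a[j] fits; a[i] only grows, so j never moves back up.
        while j >= 0 and a[i] + a[j] > mid:
            j -= 1
        if j < 0:
            break  # no partner fits for this or any larger a[i]
        count += j + 1  # a[i] pairs with every a[0..j]
    return count


print(count_sums_at_most([1, 4, 6], 7))  # prints 5  (2, 5, 5, 7, 7)
```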

This works because, even though we are not binary searching over k, the count is monotonic in mid: if there are x sums <= mid, then for k <= x the k-th sum is at most mid, and for k > x it is greater than mid.
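The argument can be checked directly on small inputs: for every candidate mid, "at least k sums are <= mid" holds exactly when the k-th sum is <= mid, so the predicate is monotone in mid and binary search finds the smallest mid where it becomes true. The randomized brute-force check below (helper names are hypothetical) is only an illustration on small values, not a proof.

```python
import random


def kth_sum(a, k):
    """k-th smallest (1-based) ordered-pair sum, by brute force."""
    return sorted(x + y for x in a for y in a)[k - 1]


def count_at_most(a, mid):
    """Number of ordered pairs (i, j) with a[i] + a[j] <= mid, by brute force."""
    return sum(1 for x in a for y in a if x + y <= mid)


random.seed(0)
for _ in range(100):
    a = [random.randint(0, 20) for _ in range(random.randint(1, 5))]
    for k in range(1, len(a) ** 2 + 1):
        target = kth_sum(a, k)
        for mid in range(41):  # every possible sum lies in 0..40
            # "at least k sums are <= mid"  <=>  "the k-th sum is <= mid"
            assert (count_at_most(a, mid) >= k) == (target <= mid)
print("monotonicity check passed")
```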