package com.freetymekiyan.algorithms.level.medium;

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

/**
 * 373. Find K Pairs with Smallest Sums
 * <p>
 * You are given two integer arrays nums1 and nums2 sorted in ascending order and an integer k.
 * <p>
 * Define a pair (u,v) which consists of one element from the first array and one element from the second array.
 * <p>
 * Find the k pairs (u1,v1),(u2,v2) ...(uk,vk) with the smallest sums.
 * <p>
 * Example 1:
 * Given nums1 = [1,7,11], nums2 = [2,4,6],  k = 3
 * <p>
 * Return: [1,2],[1,4],[1,6]
 * <p>
 * The first 3 pairs are returned from the sequence:
 * [1,2],[1,4],[1,6],[7,2],[7,4],[11,2],[7,6],[11,4],[11,6]
 * <p>
 * Example 2:
 * Given nums1 = [1,1,2], nums2 = [1,2,3],  k = 2
 * <p>
 * Return: [1,1],[1,1]
 * <p>
 * The first 2 pairs are returned from the sequence:
 * [1,1],[1,1],[1,2],[2,1],[1,2],[2,2],[1,3],[1,3],[2,3]
 * <p>
 * Example 3:
 * Given nums1 = [1,2], nums2 = [3],  k = 3
 * <p>
 * Return: [1,3],[2,3]
 * <p>
 * All possible pairs are returned from the sequence:
 * [1,3],[2,3]
 * <p>
 * Tags: Heap
 * Similar Problems: (M) Kth Smallest Element in a Sorted Matrix
 */
public class FindKPairsWithSmallestSums {

  /**
   * Heap (k-way merge).
   * <p>
   * Suppose len(nums1) = m and len(nums2) = n; there are m * n candidate pairs. For a fixed
   * nums1[i], the pairs (nums1[i], nums2[j]) are already in non-decreasing order of sum as j grows,
   * because nums2 is sorted. So the problem is a merge of sorted "rows", one per nums1 value:
   * 1. Seed a min-heap with (nums1[i], nums2[0]) for the first min(m, k) values of nums1. A row
   *    with index >= k is never needed: the k earlier row heads already give k pairs whose sums
   *    are no larger than any pair in that row.
   * 2. Repeatedly poll the smallest pair, add it to the result, and push the next pair from the
   *    same row (same nums1 value, next nums2 index) if one exists.
   * 3. Stop after k pairs, or when the heap runs out (fewer than k pairs exist).
   * <p>
   * Sums are compared as {@code long}: subtracting two int sums can leave the int range and
   * silently misorder the heap for large-magnitude inputs.
   * <p>
   * Time O(k log min(m, k)), extra space O(min(m, k)) for the heap.
   * https://discuss.leetcode.com/topic/50885/simple-java-o-klogk-solution-with-explanation/2
   *
   * @param nums1 first array, expected sorted ascending; null or empty yields no pairs
   * @param nums2 second array, expected sorted ascending; null or empty yields no pairs
   * @param k     maximum number of pairs to return; zero or negative yields no pairs
   * @return up to k pairs {u, v} (u from nums1, v from nums2) in non-decreasing order of u + v
   */
  public List<int[]> kSmallestPairs(int[] nums1, int[] nums2, int k) {
    List<int[]> res = new ArrayList<>();
    if (nums1 == null || nums2 == null || nums1.length == 0 || nums2.length == 0 || k <= 0) {
      return res;
    }
    // Heap entries are {nums1 value, nums2 value, index of that value in nums2}.
    // Sums are widened to long so the comparison cannot overflow int.
    PriorityQueue<int[]> pq = new PriorityQueue<>(Math.min(nums1.length, k),
        (a, b) -> Long.compare((long) a[0] + a[1], (long) b[0] + b[1]));
    for (int i = 0; i < nums1.length && i < k; i++) {
      pq.offer(new int[]{nums1[i], nums2[0], 0});
    }
    while (res.size() < k && !pq.isEmpty()) {
      int[] cur = pq.poll();
      res.add(new int[]{cur[0], cur[1]});
      int next = cur[2] + 1;
      if (next < nums2.length) {
        // Advance one step along nums2 for the same nums1 value.
        pq.offer(new int[]{cur[0], nums2[next], next});
      }
    }
    return res;
  }
}