Read-only web browser

Solutions

Browse the interview solutions folder with syntax highlighting.

k-closest-points-to-origin.py

heap/k-closest-points-to-origin.py · Python · 1.9 KiB · 2026-03-16 09:04

Back to folder
# Min heap, max heap, sorting or quickselect.
# heap

# Min heap - O(n + k log n) time, O(n) space
def kClosest(points: list[list[int]], k: int) -> list[list[int]]:
    """Return the ``k`` points closest to the origin, nearest first.

    Builds a min heap keyed on squared distance in one O(n) ``heapify``
    pass, then pops the ``k`` smallest entries (O(k log n)).

    Args:
        points: Points whose first two items are the x and y coordinates.
        k: Number of points wanted. Values larger than ``len(points)``
            return every point; values <= 0 return an empty list.

    Returns:
        A new list of ``[x, y]`` pairs ordered by increasing distance;
        ties are broken by x, then y (tuple comparison order).
    """
    import heapq

    # Squared distance orders points exactly like the true distance but
    # stays in exact integer arithmetic, so large coordinates do not
    # collapse into equal floats the way a square root would.
    minheap = [(p[0] ** 2 + p[1] ** 2, p[0], p[1]) for p in points]
    heapq.heapify(minheap)

    res = []
    # Never pop more entries than the heap holds.
    for _ in range(min(k, len(minheap))):
        _, x, y = heapq.heappop(minheap)
        res.append([x, y])

    return res


# Max heap - O(n log k) time, O(k) space
def kClosest(points: list[list[int]], k: int) -> list[list[int]]:
    import heapq

    maxHeap = []
    for x, y in points:
        dist = -(x ** 2 + y ** 2)
        heapq.heappush(maxHeap, [dist, x, y])
        if len(maxHeap) > k:
            heapq.heappop(maxHeap)

    res = []
    while maxHeap:
        dist, x, y = heapq.heappop(maxHeap)
        res.append([x, y])
    return res


# Sorting - O(n log n) time, O(n) space
def kClosest(points: list[list[int]], k: int) -> list[list[int]]:
    """Return the ``k`` points closest to the origin by sorting a copy.

    Args:
        points: Points whose first two items are the x and y coordinates.
            The list is left untouched; a sorted copy is built instead.
        k: Number of points wanted. Values <= 0 return an empty list and
            values above ``len(points)`` return every point.

    Returns:
        A new list of at most ``k`` points ordered by increasing distance;
        equally distant points keep their input order (stable sort).
    """
    # A negative k would otherwise slice from the end and return
    # len(points) - |k| points instead of none.
    if k <= 0:
        return []
    # Squared distance gives the same ordering as the true distance.
    return sorted(points, key=lambda p: p[0] ** 2 + p[1] ** 2)[:k]


# Quickselect - O(n) average time, O(n^2) worst-case time, O(1) space
def kClosest(points, k):
    """Return the ``k`` points closest to the origin using quickselect.

    ``points`` is reordered in place so that its first ``k`` entries are
    the closest ones; those entries are returned in no particular order.
    The pivot is always the last element of the active range, so inputs
    that keep producing an extreme pivot degrade to quadratic time.

    Args:
        points: Mutable sequence of points whose first two items are the
            x and y coordinates.
        k: Number of points wanted. Values <= 0 return an empty list and
            values >= ``len(points)`` return a copy of every point.

    Returns:
        A new list holding the ``k`` closest points.
    """
    def squared_distance(point):
        # Squared distance orders points exactly like the true distance.
        return point[0] ** 2 + point[1] ** 2

    def partition(lo, hi):
        """Partition points[lo..hi] around points[hi]; return its final index."""
        pivot_dist = squared_distance(points[hi])
        store = lo
        for j in range(lo, hi):
            # Re-arrange the points so that closer points are on the left
            # and farther points are on the right of the pivot.
            if squared_distance(points[j]) <= pivot_dist:
                points[store], points[j] = points[j], points[store]
                store += 1
        points[store], points[hi] = points[hi], points[store]
        return store

    n = len(points)
    # Without these guards a negative k never terminates and a k beyond
    # the input size indexes past the end of the list.
    if k <= 0:
        return []
    if k >= n:
        return points[:]

    left, right = 0, n - 1
    pivot = n

    # Stop once the pivot lands on index k: everything before it is then
    # no farther from the origin than everything from it onwards.
    while pivot != k:
        pivot = partition(left, right)
        if pivot < k:
            left = pivot + 1
        else:
            right = pivot - 1
    return points[:k]