Read-only web browser

Solutions

Browse the interview solutions folder with syntax highlighting.

min-cost-to-connect-all-points.py

advanced-graphs/min-cost-to-connect-all-points.py · Python · 2.1 KiB · 2026-04-26 10:54

Back to folder
# Minimum Spanning Tree: Prim's or Kruskal's. Both are O(E log E) time, O(E) space, where E is the number of edges. In this problem E = O(n^2), since the points form a complete graph.
# graphs


# Prim's (dense variant, no heap) - O(n^2) time, O(n) space
def minCostConnectPoints(points: list[list[int]]) -> int:
    """Return the minimum total Manhattan distance needed to connect all points.

    Uses the dense form of Prim's algorithm. The graph is complete (every pair
    of points is an edge), so a heap plus adjacency list would cost
    O(n^2 log n) time and O(n^2) memory. Instead this keeps, for each point
    outside the tree, the cheapest edge into the tree and scans it linearly:
    O(n^2) time, O(n) extra memory, and edge weights are computed on the fly.

    Args:
        points: [x, y] integer coordinates.

    Returns:
        Total weight of a minimum spanning tree; 0 for zero or one point.
    """
    n = len(points)
    if n == 0:
        return 0

    inf = float("inf")
    # best[j]: cheapest known edge cost from the current tree to point j.
    best = [inf] * n
    best[0] = 0  # grow the tree from point 0
    in_tree = [False] * n
    total = 0

    for _ in range(n):
        # Pick the non-tree point with the cheapest connecting edge.
        u = min((j for j in range(n) if not in_tree[j]), key=lambda j: best[j])
        in_tree[u] = True
        total += best[u]

        # Relax: u may now offer a cheaper connection to the remaining points.
        ux, uy = points[u]
        for v in range(n):
            if in_tree[v]:
                continue
            vx, vy = points[v]
            d = abs(ux - vx) + abs(uy - vy)
            if d < best[v]:
                best[v] = d

    return total


# Kruskal + Union Find - O(n^2 log n) time, O(n^2) space
class UnionFind:
    """Disjoint-set forest using union by size and full path compression.

    Slots 0..n are allocated (n + 1 in total), so node labels from 0 through
    n inclusive are valid.
    """

    def __init__(self, n: int) -> None:
        slots = n + 1
        # Every node starts as the root of its own singleton set.
        self.parent = [*range(slots)]
        self.size = [1 for _ in range(slots)]

    def find(self, node: int) -> int:
        """Return the root of node's set, re-pointing every node on the path at it."""
        root = node
        while self.parent[root] != root:
            root = self.parent[root]

        # Second pass: compress the walked path so each node links to the root.
        while node != root:
            nxt = self.parent[node]
            self.parent[node] = root
            node = nxt
        return root

    def union(self, u: int, v: int) -> bool:
        """Merge the sets holding u and v.

        Returns True if two distinct sets were merged, False if u and v were
        already in the same set.
        """
        root_u, root_v = self.find(u), self.find(v)
        if root_u == root_v:
            return False

        # Hang the smaller tree under the larger one; on a tie root_u stays root.
        if self.size[root_u] >= self.size[root_v]:
            keep, absorb = root_u, root_v
        else:
            keep, absorb = root_v, root_u
        self.parent[absorb] = keep
        self.size[keep] += self.size[absorb]
        return True

def minCostConnectPoints1(points: list[list[int]]) -> int:
    """Return the minimum total Manhattan distance to connect all points (Kruskal's).

    Builds every pairwise edge of the complete graph, sorts the edges by
    weight, and greedily keeps each edge that joins two different components,
    tracked with the UnionFind class from this module. It stops as soon as
    n - 1 edges have been kept, because the spanning tree is then complete.

    Args:
        points: [x, y] integer coordinates.

    Returns:
        Total weight of a minimum spanning tree; 0 for zero or one point.
    """
    n = len(points)

    # All n*(n-1)/2 edges as (distance, i, j); tuples sort by distance first.
    edges = []
    for i in range(n):
        x1, y1 = points[i]
        for j in range(i + 1, n):
            x2, y2 = points[j]
            edges.append((abs(x1 - x2) + abs(y1 - y2), i, j))
    edges.sort()

    uf = UnionFind(n)
    totcost = 0
    joined = 0
    for dst, p1, p2 in edges:
        if uf.union(p1, p2):
            totcost += dst
            joined += 1
            # A spanning tree has exactly n - 1 edges; any later edge would
            # only close a cycle, so skip scanning the rest of the list.
            if joined == n - 1:
                break

    return totcost