Read-only web browser

Solutions

Browse the interview solutions folder with syntax highlighting.

non-overlapping-intervals.py

intervals/non-overlapping-intervals.py · Python · 3.1 KiB · 2026-04-20 19:51

Back to folder
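# Approaches below: greedy, recursion without memoization, top-down DP, and bottom-up DP. Flip the question to "what is the maximum number of non-overlapping intervals we can keep?" and subtract that from the total to get the minimum number of removals.
# Note: every approach is named eraseOverlapIntervals, so each definition shadows the previous one; only the last one defined in the file is bound after the module runs.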
# intervals


# Greedy (sorting by start) - O(n log n) time, O(1) extra space beyond the sort
# Sort by start time, then make a greedy, locally optimal choice at each overlap: keep the interval with the smaller end time.
# Dropping the interval with the larger end is always safe, because keeping it would block at least as many upcoming intervals.
def eraseOverlapIntervals(intervals: list[list[int]]) -> int:
    """Return the minimum number of intervals to remove so the rest don't overlap.

    Intervals that merely touch (e.g. ``[1, 2]`` and ``[2, 3]``) do not overlap.
    ``intervals`` is sorted in place by start time. An empty list needs 0 removals.
    """
    intervals.sort(key=lambda x: x[0])

    remCount = 0
    # -inf guarantees the first interval is treated as non-overlapping and kept;
    # it also lets an empty list fall through to 0 instead of raising IndexError.
    prevEnd = float("-inf")
    for start, end in intervals:
        if start < prevEnd:
            # Overlap: remove one of the two, keeping whichever ends first.
            prevEnd = min(prevEnd, end)
            remCount += 1
        else:
            prevEnd = end

    return remCount


# Greedy (sorting by end) - O(n log n) time, O(1) extra space beyond the sort
# After sorting by end time, keep every interval that starts at or after the last kept end;
# anything that starts earlier overlaps the last kept interval and is removed.
def eraseOverlapIntervals(intervals: list[list[int]]) -> int:
    """Return the minimum number of intervals to remove so the rest don't overlap.

    Intervals that merely touch (e.g. ``[1, 2]`` and ``[2, 3]``) do not overlap.
    ``intervals`` is sorted in place by end time. An empty list needs 0 removals.
    """
    intervals.sort(key=lambda pair: pair[1])

    numrem = 0
    # -inf makes the first interval always kept and avoids indexing an empty list.
    prevEnd = float("-inf")
    for start, end in intervals:
        if prevEnd > start:
            numrem += 1
        else:
            prevEnd = end

    return numrem
    
# Recursion w/o memoization - O(2^n) time, O(n) space
def eraseOverlapIntervals(intervals: list[list[int]]) -> int:
    """Return the minimum number of removals via brute-force keep/drop search.

    ``intervals`` is sorted in place by start time. Every non-overlapping
    subset is explored, and the answer is the total count minus the largest
    subset found.
    """
    intervals.sort(key=lambda iv: iv[0])
    total = len(intervals)

    # best(idx, last_end): the most intervals we can still keep from position
    # idx onward, where last_end is the end of the most recently kept interval
    # (None when nothing has been kept yet).
    def best(idx, last_end):
        if idx >= total:
            return 0

        without_current = best(idx + 1, last_end)

        fits = last_end is None or last_end <= intervals[idx][0]
        if not fits:
            return without_current

        with_current = best(idx + 1, intervals[idx][1]) + 1
        return max(with_current, without_current)

    return total - best(0, None)
        

# Top-down DP - O(n^2) time, O(n) space
def eraseOverlapIntervals(intervals: list[list[int]]) -> int:
    """Return the minimum number of removals using memoized recursion.

    ``intervals`` is sorted in place by end time. An empty list needs 0 removals.
    """
    intervals.sort(key=lambda x: x[1])  # NB: sort by end time
    n = len(intervals)
    if n == 0:
        # Without this guard dfs(0) would still return 1, giving an answer of -1.
        return 0
    memo = {}

    # If we choose interval i as part of our set, what is the maximum number of
    # non-overlapping intervals we can keep from i onward (i included)?
    # Note: recursion depth can reach n along a long chain of compatible intervals.
    def dfs(i: int) -> int:
        if i in memo:
            return memo[i]

        res = 1
        for j in range(i + 1, n):
            if intervals[i][1] <= intervals[j][0]:
                res = max(res, 1 + dfs(j))
        memo[i] = res
        return res

    # Starting at interval 0 suffices: the interval with the earliest end can
    # always replace the first interval of an optimal solution without overlap.
    return n - dfs(0)


# Bottom-up DP - O(n^2) time, O(n) space
def eraseOverlapIntervals(intervals: list[list[int]]) -> int:
    """Return the minimum number of removals using a tabulated longest-chain DP.

    ``intervals`` is sorted in place by end time. An empty list needs 0 removals.
    """
    intervals.sort(key=lambda x: x[1])  # NB: sort by end time
    n = len(intervals)

    # dp[i]: the maximum number of non-overlapping intervals we can keep among
    # intervals[0..i] with interval i included as the last one kept.
    dp = [0] * n

    for i in range(n):
        dp[i] = 1  # can always keep this single interval on its own
        for j in range(i):
            if intervals[j][1] <= intervals[i][0]:
                dp[i] = max(dp[i], 1 + dp[j])

    # default=0 handles an empty list (plain max([]) raises ValueError).
    max_non_overlapping = max(dp, default=0)
    return n - max_non_overlapping