P1868 饥饿的奶牛

文章目录

题目信息

字段 内容
题号 P1868
难度 普及+
知识点 动态规划、线性 DP

题目描述

有一条奶牛冲出了围栏,来到了一处圣地(对于奶牛来说),上面用牛语写着一段文字。

现用汉语翻译为:

N 个区间,每个区间 x y 表示提供 [x, y]y - x + 1 堆优质牧草。你可以选择任意区间但不能有重复的部分。奶牛希望吃的牧草越多越好,请你帮助它。

输入格式

  • 第一行:整数 N(区间个数)
  • 接下来 N 行:每行两个整数 x y(表示一个区间 [x, y]

输出格式

一个整数,表示最多能吃到的牧草堆数(即不重叠区间的最大覆盖总长度)

样例

输入

3
1 3
7 8
3 4

输出

5

样例解析: 选择区间 [1, 3](3 堆)和 [7, 8](2 堆),共 5 堆。注意 [1, 3][3, 4] 在位置 3 重叠,不能同时选。

数据范围

  • N ≤ 150,000
  • 0 ≤ x ≤ y ≤ 3,000,000

解题思路

核心思想

本题本质上是不重叠区间选最大覆盖长度问题,可转化为线性 DP

定义 dp[i]到位置 i 为止(即覆盖区间 [0, i])能获得的最大牧草堆数。答案为 dp[maxY]

状态转移方程

对于每个区间 [l, r](长度为 len = r - l + 1),有:

dp[r] = max(dp[r], dp[l - 1] + len)

解释:若选择这个区间,则其左侧到 l-1 位置的最优值 dp[l-1] 加上该区间的长度,就是覆盖到 r 位置的最优值。

同时,每一步都要保证不选当前区间的状态也能传递下去:

dp[i] = max(dp[i], dp[i - 1])

实现细节

  1. 坐标范围大:右端点最大到 3,000,000,无法用纯坐标 DP(会 MLE/TLE),但可用压缩到最大右端点的线性遍历
  2. 按左端点排序:保证处理每个位置时,所有起点为该位置的区间都已准备好
  3. 指针扫描:用指针 j 在排序后的区间数组上移动,时间复杂度 O(maxY + N)

完整代码

C++

#include <iostream>
#include <algorithm>
using namespace std;

const int MAXN = 5000000 + 10;

struct Node {
    int x, y;
} seg[MAXN];

int dp[MAXN];
int n, maxY, ans;

bool cmp(const Node& a, const Node& b) {
    return a.x == b.x ? a.y < b.y : a.x < b.x;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> seg[i].x >> seg[i].y;
        maxY = max(maxY, seg[i].y);
    }

    sort(seg + 1, seg + n + 1, cmp);

    int j = 1;
    for (int i = 0; i <= maxY; i++) {
        // 不选当前位置区间的最优值传递下去
        dp[i] = max(dp[i], dp[i - 1]);

        // 处理所有左端点恰好为 i 的区间
        while (seg[j].x == i && j <= n) {
            int len = seg[j].y - seg[j].x + 1;
            dp[seg[j].y] = max(dp[seg[j].y], dp[seg[j].x - 1] + len);
            ans = max(ans, dp[seg[j].y]);
            j++;
        }
    }

    cout << ans << '\n';
    return 0;
}

关键代码讲解

按左端点排序

sort(seg + 1, seg + n + 1, cmp);

对所有区间按左端点升序排序,保证在遍历坐标轴时,可以用指针线性扫描找到所有以当前坐标为左端点的区间。

线性 DP 遍历

for (int i = 0; i <= maxY; i++) {
    dp[i] = max(dp[i], dp[i - 1]);
    while (seg[j].x == i && j <= n) {
        dp[seg[j].y] = max(dp[seg[j].y], dp[seg[j].x - 1] + len);
        j++;
    }
    ans = max(ans, dp[i]);
}
  • dp[i] = max(dp[i], dp[i-1]):将之前的最优值传递到当前坐标,保证即使不选任何以 i 为左端点的区间,也不会丢失之前的最优解
  • while 循环处理所有起点为 i 的区间,一次遍历完成所有状态转移
  • ans 记录过程中的最大值,最终输出

状态转移

dp[seg[j].y] = max(dp[seg[j].y], dp[seg[j].x - 1] + len);

dp[l - 1] 是覆盖到 l-1 位置的最优值,加上新区间的长度 r - l + 1,得到覆盖到 r 位置的新最优值。注意左端点可能为 0,需要保证 dp 数组下标非负。

复杂度分析

方法 时间复杂度 空间复杂度
线性 DP O(maxY + N) O(maxY)

maxY ≤ 3,000,000N ≤ 150,000,均在可接受范围内。

易错点

  1. 数组越界:当 x = 0 时,dp[x - 1]dp[-1],需要特殊处理或令 dp 数组下标从 1 开始。代码中通过令 i 从 0 开始遍历规避了此问题(dp[-1]i=0 时不访问)
  2. 排序规则:必须严格按左端点排序,否则指针扫描逻辑会出错
  3. 区间端点不唯一:同一左端点可能有多个区间,while 循环需正确处理

其他语言解法

Python

def solve() -> None:
    import sys
    input_data = sys.stdin.read().split()
    n = int(input_data[0])
    segs = []
    max_y = 0

    idx = 1
    for _ in range(n):
        x = int(input_data[idx]); y = int(input_data[idx + 1])
        segs.append((x, y))
        max_y = max(max_y, y)
        idx += 2

    segs.sort(key=lambda t: (t[0], t[1]))

    dp = [0] * (max_y + 1)
    ans = 0
    j = 0

    for i in range(max_y + 1):
        if i > 0:
            dp[i] = max(dp[i], dp[i - 1])
        while j < n and segs[j][0] == i:
            x, y = segs[j]
            length = y - x + 1
            # dp[x-1] 需要特殊处理 x==0 的情况
            prev = dp[x - 1] if x > 0 else 0
            dp[y] = max(dp[y], prev + length)
            ans = max(ans, dp[y])
            j += 1

    print(ans)


if __name__ == '__main__':
    solve()

Java

import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine().trim());

        int[][] segs = new int[n][2];
        int maxY = 0;
        for (int i = 0; i < n; i++) {
            String[] parts = br.readLine().trim().split("\\s+");
            segs[i][0] = Integer.parseInt(parts[0]);
            segs[i][1] = Integer.parseInt(parts[1]);
            maxY = Math.max(maxY, segs[i][1]);
        }

        Arrays.sort(segs, Comparator.comparingInt((int[] a) -> a[0]));

        int[] dp = new int[maxY + 1];
        int ans = 0;
        int j = 0;

        for (int i = 0; i <= maxY; i++) {
            if (i > 0) dp[i] = Math.max(dp[i], dp[i - 1]);
            while (j < n && segs[j][0] == i) {
                int x = segs[j][0], y = segs[j][1];
                int prev = x > 0 ? dp[x - 1] : 0;
                dp[y] = Math.max(dp[y], prev + y - x + 1);
                ans = Math.max(ans, dp[y]);
                j++;
            }
        }
        System.out.println(ans);
    }
}

Go

package main

import (
	"bufio"
	"fmt"
	"os"
	"sort"
)

func main() {
	in := bufio.NewReader(os.Stdin)
	var n int
	fmt.Fscan(in, &n)

	type seg struct{ x, y int }
	segs := make([]seg, n)
	maxY := 0
	for i := 0; i < n; i++ {
		fmt.Fscan(in, &segs[i].x, &segs[i].y)
		if segs[i].y > maxY {
			maxY = segs[i].y
		}
	}

	sort.Slice(segs, func(i, j int) bool {
		if segs[i].x == segs[j].x {
			return segs[i].y < segs[j].y
		}
		return segs[i].x < segs[j].x
	})

	dp := make([]int, maxY+1)
	ans := 0
	j := 0

	for i := 0; i <= maxY; i++ {
		if i > 0 && dp[i-1] > dp[i] {
			dp[i] = dp[i-1]
		}
		for j < n && segs[j].x == i {
			x, y := segs[j].x, segs[j].y
			prev := 0
			if x > 0 {
				prev = dp[x-1]
			}
			length := y - x + 1
			if dp[y] < prev+length {
				dp[y] = prev + length
			}
			if dp[y] > ans {
				ans = dp[y]
			}
			j++
		}
	}
	fmt.Println(ans)
}

JavaScript

'use strict';

/**
 * @param {number} n - 区间数量
 * @param {[number, number][]} segs - [[x, y], ...]
 * @returns {number} - 最大牧草堆数
 */
function solve(n, segs) {
    segs.sort((a, b) => a[0] - b[0] || a[1] - b[1]);

    const maxY = segs.reduce((mx, s) => Math.max(mx, s[1]), 0);
    const dp = new Int32Array(maxY + 1);
    let ans = 0;
    let j = 0;

    for (let i = 0; i <= maxY; i++) {
        if (i > 0) dp[i] = Math.max(dp[i], dp[i - 1]);
        while (j < n && segs[j][0] === i) {
            const [x, y] = segs[j];
            const prev = x > 0 ? dp[x - 1] : 0;
            dp[y] = Math.max(dp[y], prev + y - x + 1);
            ans = Math.max(ans, dp[y]);
            j++;
        }
    }
    return ans;
}

const input = require('fs').readFileSync('/dev/stdin', 'utf8').trim().split('\n');
const n = parseInt(input[0]);
const segs = [];
for (let i = 1; i <= n; i++) {
    const [x, y] = input[i].split(' ').map(Number);
    segs.push([x, y]);
}
process.stdout.write(String(solve(n, segs)));