P2015 二叉苹果树
文章目录
题目信息
| 字段 | 内容 |
|---|---|
| 题号 | P2015 |
| 难度 | 绿题(普及+/提高) |
| 知识点 | 树形动态规划、背包DP |
题目描述
有一棵苹果树,如果树枝有分叉,一定是分 2 叉(就是说没有只有 1 个儿子的结点)。这棵树共有 N 个结点(叶子点或者树枝分叉点),编号为 1∼N,共有 N-1 条树枝。
已知每条树枝上苹果的数量,求:
- 给定需要保留的树枝数量
m - 计算能留住的最多苹果数
注:留住一个苹果的定义是苹果所在枝条直接与根相连。
解题思路
核心思想
这是一道经典的树形背包DP问题。
- 树的存储:使用邻接表存储二叉树
- 状态定义:
dp[u][j]表示以结点u为根的子树中,保留j条边时能获得的最大苹果数 - 状态转移:
- 对于每个子节点,分两种情况讨论
- 不选该子节点:则该子树的边全部不保留
- 选该子节点:需要分配
k+1条边(k条给子树,1条给连接父子节点的边)
算法流程
dfs(u, fa):
d = 0 // 当前子树的总边数
for each child v of u:
child_edges = dfs(v, u) + 1 // 子树边数 + 连接边
d += child_edges
for j from d down to 1:
for k from j-1 down to 0:
dp[u][j] = max(dp[u][j], dp[u][j-k-1] + dp[v][k] + edge_value)
return d
完整代码
C++
#include <cstdio>
#include <vector>
using namespace std;
const int kMaxN = 105; // N <= 100,树上最多 100 个节点
struct Edge {
int to;
int val;
Edge() {}
Edge(int to_node, int weight) : to(to_node), val(weight) {}
};
// 邻接表:G[i] 存储节点 i 的所有邻接边
vector<Edge> G[kMaxN];
int n, m;
int dp[kMaxN][kMaxN]; // dp[u][j] = 以 u 为根的子树,保留 j 条边时的最大苹果数
/**
* 深度优先搜索 + 树形背包 DP
*
* @param x 当前节点编号
* @param fa 当前节点的父节点编号(防止往回走)
* @return 以 x 为根的子树的边数
*/
int Dfs(int x, int fa) {
int total_edges = 0; // 当前已累计的边数(用于内层循环上界)
for (int i = 0; i < (int)G[x].size(); ++i) {
const Edge& e = G[x][i];
if (e.to == fa) {
continue; // 跳过回边
}
// 递归处理子节点,返回子树的边数
int child_subtree_edges = Dfs(e.to, x);
// 将子节点的一条连接边计入(连接 x 与 e.to 的边)
int child_edges = child_subtree_edges + 1;
total_edges += child_edges;
// 树形背包:枚举分配给 x 的总边数 j
// j 的上界是 min(m, total_edges),因为最多只能保留 m 条边
for (int j = min(m, total_edges); j > 0; --j) {
// 枚举分配给当前子节点 e.to 的边数 k
// k 可以从 0 到 min(m, j-1),其中 j-1 留给了 e.to 与 x 的连接边
for (int k = min(m, j - 1); k >= 0; --k) {
// 状态转移:
// - 不选当前子节点:保持 dp[x][j]
// - 选当前子节点:dp[x][j-k-1](给其他子树的)+ dp[e.to][k](当前子树的)+ e.val(连接边苹果)
dp[x][j] = max(dp[x][j],
dp[x][j - k - 1] + dp[e.to][k] + e.val);
}
}
}
return total_edges;
}
int main() {
// 读取输入:n 个节点,需要保留 m 条边
scanf("%d%d", &n, &m);
// 构建无向图(树)
for (int i = 0; i < n - 1; ++i) {
int x, y, v;
scanf("%d%d%d", &x, &y, &v);
G[x].push_back(Edge(y, v));
G[y].push_back(Edge(x, v));
}
// 从根节点 1 开始 DFS
Dfs(1, 0);
// 答案:根节点保留 m 条边时的最大苹果数
printf("%d\n", dp[1][m]);
return 0;
}
关键代码讲解
状态转移方程
dp[x][j] = max(dp[x][j], dp[x][j - k - 1] + dp[s.to][k] + s.val);
j:分配给以x为根的子树的总边数k:分配给子节点s.to子树的边数k + 1:其中 1 条是连接x与s.to的边dp[x][j - k - 1]:分配给其他子树的边dp[s.to][k]:子节点子树保留k条边的最大苹果数s.val:连接边上的苹果数
为什么从大到小枚举?
for (reg int j = d; j >= 1; --j)
这是一维背包的常用技巧:
- 由于
dp[x][j - k - 1]引用的是同一层循环内的值 - 从大到小遍历可以保证每件物品只被使用一次(0-1背包特性)
示例解析
假设一棵树:
1(根)
/ \
2 3
/ \ \
4 5 6
边权分别为:1-2:5, 1-3:3, 2-4:2, 2-5:4, 3-6:6
若 m = 2,选择边 1-2(5) 和 2-5(4),最大苹果数 = 9
复杂度分析
| 方法 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 树形DP | O(n × m^2) |
O(n × m) |
易错点
- 数组大小:
dp[M+5][M+5]中M=100,需确保m <= 100 - 根节点选择:树以
1为根进行 DFS - 边数统计:每条边只能算一次,递归返回
child_edges = dfs(v,u) + 1 - 初始化:
dp数组需初始化为 0
其他语言解法
Python
import sys
sys.setrecursionlimit(1000000)
def main():
data = list(map(int, sys.stdin.read().split()))
n, m = data[0], data[1]
G = [[] for _ in range(n + 1)]
idx = 2
for _ in range(n - 1):
x, y, v = data[idx], data[idx + 1], data[idx + 2]
idx += 3
G[x].append((y, v))
G[y].append((x, v))
dp = [[0] * (m + 1) for _ in range(n + 1)]
def dfs(x, fa):
d = 0
for y, val in G[x]:
if y == fa:
continue
child_edges = dfs(y, x) + 1
d += child_edges
for j in range(min(m, d), 0, -1):
for k in range(min(m, j - 1), -1, -1):
dp[x][j] = max(dp[x][j],
dp[x][j - k - 1] + dp[y][k] + val)
return d
dfs(1, 0)
print(dp[1][m])
if __name__ == '__main__':
main()
Java
import java.io.*;
import java.util.*;
public class Main {
static int n, m;
static List<int[]>[] G;
static int[][] dp;
static int dfs(int x, int fa) {
int d = 0;
for (int[] e : G[x]) {
int y = e[0], val = e[1];
if (y == fa) continue;
int childEdges = dfs(y, x) + 1;
d += childEdges;
for (int j = Math.min(m, d); j >= 1; j--) {
for (int k = Math.min(m, j - 1); k >= 0; k--) {
dp[x][j] = Math.max(dp[x][j],
dp[x][j - k - 1] + dp[y][k] + val);
}
}
}
return d;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
n = Integer.parseInt(st.nextToken());
m = Integer.parseInt(st.nextToken());
G = new ArrayList[n + 1];
for (int i = 1; i <= n; i++) G[i] = new ArrayList<>();
for (int i = 0; i < n - 1; i++) {
st = new StringTokenizer(br.readLine());
int x = Integer.parseInt(st.nextToken());
int y = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
G[x].add(new int[]{y, v});
G[y].add(new int[]{x, v});
}
dp = new int[n + 1][m + 1];
dfs(1, 0);
System.out.println(dp[1][m]);
}
}
Go
package main
import (
"bufio"
"fmt"
"os"
)
var n, m int
var G [][][2]int
var dp [][]int
func min(a, b int) int {
if a < b {
return a
}
return b
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func dfs(x, fa int) int {
d := 0
for _, e := range G[x] {
y, val := e[0], e[1]
if y == fa {
continue
}
childEdges := dfs(y, x) + 1
d += childEdges
for j := min(m, d); j >= 1; j-- {
for k := min(m, j-1); k >= 0; k-- {
dp[x][j] = max(dp[x][j],
dp[x][j-k-1]+dp[y][k]+val)
}
}
}
return d
}
func main() {
in := bufio.NewReader(os.Stdin)
fmt.Fscan(in, &n, &m)
G = make([][][2]int, n+1)
for i := 1; i <= n; i++ {
G[i] = make([][2]int, 0)
}
for i := 0; i < n-1; i++ {
var x, y, v int
fmt.Fscan(in, &x, &y, &v)
G[x] = append(G[x], [2]int{y, v})
G[y] = append(G[y], [2]int{x, v})
}
dp = make([][]int, n+1)
for i := 1; i <= n; i++ {
dp[i] = make([]int, m+1)
}
dfs(1, 0)
fmt.Println(dp[1][m])
}
JavaScript
'use strict';
const fs = require('fs');
function main() {
const data = fs.readFileSync('/dev/stdin', 'utf8').trim().split(/\s+/).map(Number);
const n = data[0], m = data[1];
const G = Array.from({ length: n + 1 }, () => []);
let idx = 2;
for (let i = 0; i < n - 1; i++) {
const x = data[idx++], y = data[idx++], v = data[idx++];
G[x].push([y, v]);
G[y].push([x, v]);
}
const dp = Array.from({ length: n + 1 }, () => Array(m + 1).fill(0));
function dfs(x, fa) {
let d = 0;
for (const [y, val] of G[x]) {
if (y === fa) continue;
const childEdges = dfs(y, x) + 1;
d += childEdges;
for (let j = d; j >= 1; j--) {
for (let k = childEdges - 1; k >= 0; k--) {
if (j - k - 1 >= 0) {
dp[x][j] = Math.max(dp[x][j],
dp[x][j - k - 1] + dp[y][k] + val);
}
}
}
}
return d;
}
dfs(1, 0);
console.log(dp[1][m]);
}
main();