P2015 二叉苹果树

文章目录

题目信息

字段 内容
题号 P2015
难度 绿题(普及+/提高)
知识点 树形动态规划、背包DP

题目描述

有一棵苹果树,如果树枝有分叉,一定是分 2 叉(就是说没有只有 1 个儿子的结点)。这棵树共有 N 个结点(叶子点或者树枝分叉点),编号为 1∼N,共有 N-1 条树枝。

已知每条树枝上苹果的数量,求:

  • 给定需要保留的树枝数量 m
  • 计算能留住的最多苹果数

:留住一个苹果的定义是苹果所在枝条直接与根相连。

解题思路

核心思想

这是一道经典的树形背包DP问题。

  1. 树的存储:使用邻接表存储二叉树
  2. 状态定义dp[u][j] 表示以结点 u 为根的子树中,保留 j 条边时能获得的最大苹果数
  3. 状态转移
    • 对于每个子节点,分两种情况讨论
    • 不选该子节点:则该子树的边全部不保留
    • 该子节点:需要分配 k+1 条边(k 条给子树,1 条给连接父子节点的边)

算法流程

dfs(u, fa):
    d = 0  // 当前子树的总边数
    for each child v of u:
        child_edges = dfs(v, u) + 1  // 子树边数 + 连接边
        d += child_edges
        
        for j from d down to 1:
            for k from j-1 down to 0:
                dp[u][j] = max(dp[u][j], dp[u][j-k-1] + dp[v][k] + edge_value)
    return d

完整代码

C++

#include <cstdio>
#include <vector>

using namespace std;

const int kMaxN = 105;  // N <= 100,树上最多 100 个节点

struct Edge {
    int to;
    int val;
    Edge() {}
    Edge(int to_node, int weight) : to(to_node), val(weight) {}
};

// 邻接表:G[i] 存储节点 i 的所有邻接边
vector<Edge> G[kMaxN];

int n, m;
int dp[kMaxN][kMaxN];  // dp[u][j] = 以 u 为根的子树,保留 j 条边时的最大苹果数

/**
 * 深度优先搜索 + 树形背包 DP
 *
 * @param x 当前节点编号
 * @param fa 当前节点的父节点编号(防止往回走)
 * @return 以 x 为根的子树的边数
 */
int Dfs(int x, int fa) {
    int total_edges = 0;  // 当前已累计的边数(用于内层循环上界)

    for (int i = 0; i < (int)G[x].size(); ++i) {
        const Edge& e = G[x][i];
        if (e.to == fa) {
            continue;  // 跳过回边
        }

        // 递归处理子节点,返回子树的边数
        int child_subtree_edges = Dfs(e.to, x);

        // 将子节点的一条连接边计入(连接 x 与 e.to 的边)
        int child_edges = child_subtree_edges + 1;
        total_edges += child_edges;

        // 树形背包:枚举分配给 x 的总边数 j
        // j 的上界是 min(m, total_edges),因为最多只能保留 m 条边
        for (int j = min(m, total_edges); j > 0; --j) {
            // 枚举分配给当前子节点 e.to 的边数 k
            // k 可以从 0 到 min(m, j-1),其中 j-1 留给了 e.to 与 x 的连接边
            for (int k = min(m, j - 1); k >= 0; --k) {
                // 状态转移:
                // - 不选当前子节点:保持 dp[x][j]
                // - 选当前子节点:dp[x][j-k-1](给其他子树的)+ dp[e.to][k](当前子树的)+ e.val(连接边苹果)
                dp[x][j] = max(dp[x][j],
                               dp[x][j - k - 1] + dp[e.to][k] + e.val);
            }
        }
    }
    return total_edges;
}

int main() {
    // 读取输入:n 个节点,需要保留 m 条边
    scanf("%d%d", &n, &m);

    // 构建无向图(树)
    for (int i = 0; i < n - 1; ++i) {
        int x, y, v;
        scanf("%d%d%d", &x, &y, &v);
        G[x].push_back(Edge(y, v));
        G[y].push_back(Edge(x, v));
    }

    // 从根节点 1 开始 DFS
    Dfs(1, 0);

    // 答案:根节点保留 m 条边时的最大苹果数
    printf("%d\n", dp[1][m]);

    return 0;
}

关键代码讲解

状态转移方程

dp[x][j] = max(dp[x][j], dp[x][j - k - 1] + dp[s.to][k] + s.val);
  • j:分配给以 x 为根的子树的总边数
  • k:分配给子节点 s.to 子树的边数
  • k + 1:其中 1 条是连接 xs.to 的边
  • dp[x][j - k - 1]:分配给其他子树的边
  • dp[s.to][k]:子节点子树保留 k 条边的最大苹果数
  • s.val:连接边上的苹果数

为什么从大到小枚举?

for (reg int j = d; j >= 1; --j)

这是一维背包的常用技巧:

  • 由于 dp[x][j - k - 1] 引用的是同一层循环内的值
  • 从大到小遍历可以保证每件物品只被使用一次(0-1背包特性)

示例解析

假设一棵树:

    1(根)
   / \
  2   3
 / \   \
4   5   6

边权分别为:1-2:5, 1-3:3, 2-4:2, 2-5:4, 3-6:6

m = 2,选择边 1-2(5)2-5(4),最大苹果数 = 9

复杂度分析

方法 时间复杂度 空间复杂度
树形DP O(n × m^2) O(n × m)

易错点

  1. 数组大小dp[M+5][M+5]M=100,需确保 m <= 100
  2. 根节点选择:树以 1 为根进行 DFS
  3. 边数统计:每条边只能算一次,递归返回 child_edges = dfs(v,u) + 1
  4. 初始化dp 数组需初始化为 0

其他语言解法

Python

import sys
sys.setrecursionlimit(1000000)

def main():
    data = list(map(int, sys.stdin.read().split()))
    n, m = data[0], data[1]

    G = [[] for _ in range(n + 1)]
    idx = 2
    for _ in range(n - 1):
        x, y, v = data[idx], data[idx + 1], data[idx + 2]
        idx += 3
        G[x].append((y, v))
        G[y].append((x, v))

    dp = [[0] * (m + 1) for _ in range(n + 1)]

    def dfs(x, fa):
        d = 0
        for y, val in G[x]:
            if y == fa:
                continue
            child_edges = dfs(y, x) + 1
            d += child_edges

            for j in range(min(m, d), 0, -1):
                for k in range(min(m, j - 1), -1, -1):
                    dp[x][j] = max(dp[x][j],
                                   dp[x][j - k - 1] + dp[y][k] + val)
        return d

    dfs(1, 0)
    print(dp[1][m])


if __name__ == '__main__':
    main()

Java

import java.io.*;
import java.util.*;

public class Main {
    static int n, m;
    static List<int[]>[] G;
    static int[][] dp;

    static int dfs(int x, int fa) {
        int d = 0;
        for (int[] e : G[x]) {
            int y = e[0], val = e[1];
            if (y == fa) continue;

            int childEdges = dfs(y, x) + 1;
            d += childEdges;

            for (int j = Math.min(m, d); j >= 1; j--) {
                for (int k = Math.min(m, j - 1); k >= 0; k--) {
                    dp[x][j] = Math.max(dp[x][j],
                        dp[x][j - k - 1] + dp[y][k] + val);
                }
            }
        }
        return d;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());
        m = Integer.parseInt(st.nextToken());

        G = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++) G[i] = new ArrayList<>();

        for (int i = 0; i < n - 1; i++) {
            st = new StringTokenizer(br.readLine());
            int x = Integer.parseInt(st.nextToken());
            int y = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());
            G[x].add(new int[]{y, v});
            G[y].add(new int[]{x, v});
        }

        dp = new int[n + 1][m + 1];
        dfs(1, 0);
        System.out.println(dp[1][m]);
    }
}

Go

package main

import (
    "bufio"
    "fmt"
    "os"
)

var n, m int
var G [][][2]int
var dp [][]int

func min(a, b int) int {
    if a < b {
        return a
    }
    return b
}

func max(a, b int) int {
    if a > b {
        return a
    }
    return b
}

func dfs(x, fa int) int {
    d := 0
    for _, e := range G[x] {
        y, val := e[0], e[1]
        if y == fa {
            continue
        }
        childEdges := dfs(y, x) + 1
        d += childEdges

        for j := min(m, d); j >= 1; j-- {
            for k := min(m, j-1); k >= 0; k-- {
                dp[x][j] = max(dp[x][j],
                    dp[x][j-k-1]+dp[y][k]+val)
            }
        }
    }
    return d
}

func main() {
    in := bufio.NewReader(os.Stdin)
    fmt.Fscan(in, &n, &m)

    G = make([][][2]int, n+1)
    for i := 1; i <= n; i++ {
        G[i] = make([][2]int, 0)
    }

    for i := 0; i < n-1; i++ {
        var x, y, v int
        fmt.Fscan(in, &x, &y, &v)
        G[x] = append(G[x], [2]int{y, v})
        G[y] = append(G[y], [2]int{x, v})
    }

    dp = make([][]int, n+1)
    for i := 1; i <= n; i++ {
        dp[i] = make([]int, m+1)
    }

    dfs(1, 0)
    fmt.Println(dp[1][m])
}

JavaScript

'use strict';

const fs = require('fs');

function main() {
    const data = fs.readFileSync('/dev/stdin', 'utf8').trim().split(/\s+/).map(Number);
    const n = data[0], m = data[1];

    const G = Array.from({ length: n + 1 }, () => []);
    let idx = 2;
    for (let i = 0; i < n - 1; i++) {
        const x = data[idx++], y = data[idx++], v = data[idx++];
        G[x].push([y, v]);
        G[y].push([x, v]);
    }

    const dp = Array.from({ length: n + 1 }, () => Array(m + 1).fill(0));

    function dfs(x, fa) {
        let d = 0;
        for (const [y, val] of G[x]) {
            if (y === fa) continue;

            const childEdges = dfs(y, x) + 1;
            d += childEdges;

            for (let j = d; j >= 1; j--) {
                for (let k = childEdges - 1; k >= 0; k--) {
                    if (j - k - 1 >= 0) {
                        dp[x][j] = Math.max(dp[x][j],
                            dp[x][j - k - 1] + dp[y][k] + val);
                    }
                }
            }
        }
        return d;
    }

    dfs(1, 0);
    console.log(dp[1][m]);
}

main();