P1063 能量项链

文章目录

题目信息

字段 内容
题号 P1063
难度 蓝题(NOIP 2006 提高组)
知识点 动态规划、区间 DP、环形 DP

题目描述

在 Mars 星球上,每个 Mars 人都随身佩带着一串能量项链。项链上有 N 颗能量珠,每颗能量珠都有头标记尾标记

相邻两颗珠子满足:前一颗的尾标记 = 后一颗的头标记。

当两颗相邻珠子聚合时,若前一颗为 (m, r),后一颗为 (r, n),则释放能量 m × r × n,生成的新珠子标记为 (m, n)

最终项链会聚合为一颗珠子,请你设计聚合顺序,使释放的总能量最大。

输入格式

  • 第一行:整数 N(珠子个数)
  • 第二行:N 个整数,表示第 i 颗珠子的头标记;第 i 颗珠子的尾标记 = 第 i+1 颗珠子的头标记;第 N 颗珠子的尾标记 = 第 1 颗珠子的头标记

输出格式

一个整数,表示最优聚合顺序释放的最大总能量

样例

输入

4
2 3 5 10

输出

710

样例解析: 最优顺序 ((4⊕1)⊕2)⊕3,各次聚合能量为 60、150、500,总和 710。

数据范围

  • N ≤ 100
  • 标记值 ≤ 1000
  • 最大总能量 ≤ 2.1 × 10⁹

解题思路

核心思想

本题是区间 DP 的经典问题,与"矩阵连乘"本质相同。

由于项链是环形的,处理方法是:将珠子序列在末尾复制一份,例如 a[1..N] 展开为 a[1..2N],枚举每一种以 i 为起点的长度为 N 的区间。

状态定义

dp[i][j] 表示将i 颗到第 j 颗珠子(按顺序连续)合并为一颗珠子所释放的最大总能量

合并后新珠子的标记为:头 = a[i]尾 = a[j+1](因为第 j 颗珠子的尾标记 = 第 j+1 颗珠子的头标记)。

状态转移

枚举最后一次聚合的分割点 ki ≤ k < j),此时 [i, j] 被分成 [i, k][k+1, j] 两部分:

dp[i][j] = max( dp[i][k] + dp[k+1][j] + a[i] × a[k+1] × a[j+1] )

解释:

  • dp[i][k]:合并前一段的最大能量,合并后珠子标记为 (a[i], a[k+1])
  • dp[k+1][j]:合并后一段的最大能量,合并后珠子标记为 (a[k+1], a[j+1])
  • 最后这两颗再合并:a[i] × a[k+1] × a[j+1]

环形处理

遍历 i = 1..N,以 i 为起点,求解区间 [i, i+N-1] 的 DP 值,取最大即为答案。这样等价于穷举了环上所有可能的起始位置。

完整代码

C++

#include <iostream>
using namespace std;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;
    int a[310] = {0};

    // 读入并复制一份(处理环形)
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        a[i + n] = a[i];          // 复制到 i+n 处
    }

    // dp[i][j] 表示合并 [i..j] 的最大能量
    int dp[310][310] = {0};
    int ans = -1;

    // 区间长度从 2 到 n
    for (int len = 2; len <= n; len++) {
        // 起点 i,最大到 2n-len+1
        for (int i = 1; i + len - 1 <= 2 * n; i++) {
            int j = i + len - 1;    // 区间右端点
            // 枚举分割点 k
            for (int k = i; k < j; k++) {
                int head = a[i];
                int mid  = a[k + 1];
                int tail = a[j + 1];
                dp[i][j] = max(dp[i][j], dp[i][k] + dp[k + 1][j] + head * mid * tail);
            }
            // 如果区间长度恰好为 n(完整一圈),更新答案
            if (len == n) {
                ans = max(ans, dp[i][j]);
            }
        }
    }

    cout << ans << '\n';
    return 0;
}

代码核心解读

dp[i][j] = max(dp[i][j], dp[i][k] + dp[k+1][j] + a[i] * a[k+1] * a[j+1]);
  • a[i]:第 i 颗珠子的头标记,也是 [i, k] 合并后新珠子的头标记
  • a[k+1]:分界线,既是左段的尾标记也是右段的头标记
  • a[j+1]:第 j 颗珠子的尾标记,也是右段合并后新珠子的尾标记
  • a[i] * a[k+1] * a[j+1]:最后一步合并这两颗珠子释放的能量

关键代码讲解

环形展开

for (int i = 1; i <= n; i++) {
    cin >> a[i];
    a[i + n] = a[i];
}

将数组复制一份,使得原本环形的链可以用线性区间 [i, j] 来表示,不必在 DP 中处理环的回绕逻辑。

区间 DP 遍历顺序

for (int len = 2; len <= n; len++) {          // 区间长度
    for (int i = 1; i + len - 1 <= 2 * n; i++) {  // 起点
        int j = i + len - 1;                   // 终点
        for (int k = i; k < j; k++) {          // 枚举分割点

区间 DP 必须按长度递增的顺序遍历,这样才能保证 dp[i][k]dp[k+1][j] 已经被计算过。

取完整环的答案

if (len == n) {
    ans = max(ans, dp[i][j]);
}

只有区间长度恰好等于 n 时,才是覆盖了完整一圈的合法方案,此时取最大值。

复杂度分析

方法 时间复杂度 空间复杂度
区间 DP O(n³) O(n²)
  • 长度 O(n) × 起点 O(n) × 分割点 O(n) → O(n³)
  • n ≤ 100,约 10⁶ 次运算,可通过

易错点

  1. 环形数组复制:必须将数组复制到 a[i+n],否则取 a[j+1] 时会越界(当 j = 2n 时)
  2. dp 数组下标a[j+1]j 最大为 2n-1,所以 a 数组大小至少 2n+2
  3. 区间长度从 2 开始:单个珠子 len=1 无需聚合,直接 dp[i][i]=0
  4. 答案不是 dp[1][n]:因为是环形,正确答案是 max(dp[i][i+n-1]) for i = 1..n

其他语言解法

Python

def solve() -> None:
    n: int = int(input())
    beads = list(map(int, input().split()))

    # a[0]=0(占位), a[1..n]=beads, a[n+1..2n]=beads(环形展开)
    # a[2n+1] = beads[0],第 2n 颗珠子尾标记 = 第 1 颗头标记
    a: list[int] = [0] + beads + beads + [beads[0]]

    dp = [[0] * (2 * n + 2) for _ in range(2 * n + 2)]
    ans = 0

    # 区间长度从 2 到 n
    for length in range(2, n + 1):
        for i in range(1, 2 * n - length + 2):
            j = i + length - 1
            for k in range(i, j):
                val = dp[i][k] + dp[k + 1][j] + a[i] * a[k + 1] * a[j + 1]
                if dp[i][j] < val:
                    dp[i][j] = val
            if length == n:
                ans = max(ans, dp[i][j])

    print(ans)


if __name__ == '__main__':
    solve()

Java

import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine().trim());
        StringTokenizer st = new StringTokenizer(br.readLine());

        int[] a = new int[2 * n + 2];
        for (int i = 1; i <= n; i++) {
            a[i] = Integer.parseInt(st.nextToken());
            a[i + n] = a[i];
        }

        int[][] dp = new int[2 * n + 2][2 * n + 2];
        int ans = -1;

        for (int len = 2; len <= n; len++) {
            for (int i = 1; i + len - 1 <= 2 * n; i++) {
                int j = i + len - 1;
                for (int k = i; k < j; k++) {
                    int val = dp[i][k] + dp[k + 1][j]
                            + a[i] * a[k + 1] * a[j + 1];
                    dp[i][j] = Math.max(dp[i][j], val);
                }
                if (len == n) {
                    ans = Math.max(ans, dp[i][j]);
                }
            }
        }
        System.out.println(ans);
    }
}

Go

package main

import (
	"bufio"
	"fmt"
	"os"
)

func max(a, b int) int {
	if a > b {
		return a
	}
	return b
}

func main() {
	in := bufio.NewReader(os.Stdin)
	var n int
	fmt.Fscan(in, &n)

	a := make([]int, 2*n+2)
	for i := 1; i <= n; i++ {
		fmt.Fscan(in, &a[i])
		a[i+n] = a[i]
	}

	dp := make([][]int, 2*n+2)
	for i := range dp {
		dp[i] = make([]int, 2*n+2)
	}
	ans := -1

	for length := 2; length <= n; length++ {
		for i := 1; i+length-1 <= 2*n; i++ {
			j := i + length - 1
			for k := i; k < j; k++ {
				val := dp[i][k] + dp[k+1][j] + a[i]*a[k+1]*a[j+1]
				if dp[i][j] < val {
					dp[i][j] = val
				}
			}
			if length == n && dp[i][j] > ans {
				ans = dp[i][j]
			}
		}
	}
	fmt.Println(ans)
}

JavaScript

'use strict';

/**
 * @param {number} n - 珠子数量
 * @param {number[]} beads - 头标记数组(下标从1开始)
 * @returns {number} - 最大总能量
 */
function solve(n, beads) {
    // 环形展开:复制一份
    const a = new Array(2 * n + 2).fill(0);
    for (let i = 1; i <= n; i++) {
        a[i] = beads[i - 1];
        a[i + n] = a[i];
    }

    const dp = Array.from({ length: 2 * n + 2 }, () => new Int32Array(2 * n + 2));
    let ans = -1;

    // 区间长度从2到n
    for (let len = 2; len <= n; len++) {
        for (let i = 1; i + len - 1 <= 2 * n; i++) {
            const j = i + len - 1;
            for (let k = i; k < j; k++) {
                const val = dp[i][k] + dp[k + 1][j] + a[i] * a[k + 1] * a[j + 1];
                if (dp[i][j] < val) dp[i][j] = val;
            }
            if (len === n && dp[i][j] > ans) {
                ans = dp[i][j];
            }
        }
    }
    return ans;
}

const rl = require('readline').createInterface({ input: process.stdin });
const lines = [];
rl.on('line', line => lines.push(line.trim()));
rl.on('close', () => {
    const n = parseInt(lines[0]);
    const beads = lines[1].split(' ').map(Number);
    process.stdout.write(String(solve(n, beads)));
});