奥法之劫 题解

这里是 O(n)O(n) 的做法。

我们先从一个朴素的 O(n3)O(n^3) dp讲起。

f[i]f[i] 表示我处理到 ii 号节点的最小代价和,那么转移有:

f[i]=min(f[k]+cost(k+1,i))f[i]=\min(f[k]+\mathrm{cost}(k+1,i))

那么我们可以进行一步步优化到 O(n)O(n)


先关注 pi0p_i\geqslant 0 的情况,有一些我们不得不拆的 aia_i,我把它们称作“代价”。

这些代价不仅和当前 ii 的位置是有关,还与 bb 的高度有关。

我们发现代价是不好计算的。随着 ii 的变化,代价也随之变化,所以每次都要 O(n)O(n) 扫描。

换一种思路,我们可以考虑每个 aia_i 对答案的贡献,把它们挂在对应的节点上,就可以快速计算。

也就是说,我们要找到对于每一个 aka_k,找到第一个 aia_i,使得选中 aia_i 时必须要拆除 aka_k

具体而言,我们可以在 bb 数组中 lower_bound aklower\_bound\ a_k,把 pkp_k 加到对应的高度上。

那么每次我想选一个 aia_i,就必须要拆除所有挂在 bjb_j 上的代价,因为它们没有被选中且未被挡住。(其中 ai=bja_i=b_j

这样我们砍掉了计算区间贡献的 nn,而换成了 logn\log n

再想想,由于 bb 数组是单调的,所以我们可以一遍扫一遍存下来 lower_boundlower\_bound 的结果,就可以做到 O(1)O(1) 了。


但此时,我们的转移点还是不确定,并且没有考虑 pi<0p_i<0 的情况。

对于 pi<0p_i<0 的部分,我们贪心的想肯定是越选多越好,所以除了把 ii 选中的情况我们的答案都应该加上这些 pip_i

没错,这里的影响就只和 ii 的位置有关了,同样有关的还有我们的 ff 数组。

其实我们 dpdp 的本质,就是在最小化这些东西,而之前的操作是为了方便计算必须要拆的代价。

我们设 aia_i 对应的 bb 的位置为 posipos_iii 之前 pi<0p_i<0 部分的和为 sumisum_i,那么转移有:

f[i]=minposk<posi(f[k]sumk)+sumi+costjf[i]=\min_{pos_k<pos_i}(f[k]-sum_k)+sum_i+cost_{j}

其中 costjcost_j 表示在 ii 位置保留 bjb_j 高度的代价。

我们发现 minmin 里面的可以记一个前缀最小值,这样就可以 O(1)O(1) 转移了。

这样最小化了能够最小化的部分,也统计了代价。


最后注意几点:

  1. 为了保证最后答案能够统计到,要在 n+1n+1 的位置插一个极大值。

  2. 如果存在 aibja_i\neq b_j 对于任意的 jj ,那么它不能参与统计答案,但必须要参与代价以及 sumsum 的计算。

  3. 其实 ffsumsum 数组根本没必要开。~

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
const int N=5e6+9;
const ll INF=1e18;
inline int read(){
	int res=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-') f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		res=(res<<1)+(res<<3)+ch-'0';
		ch=getchar();
	}
	return res*f;
}
int n,a[N],b[N],Q[N],m,pos[N];
ll f[N],g[N],cost[N],p[N],sum;
int main()
{
	// freopen("hs.in","r",stdin);
	// freopen("offa.out","w",stdout);
	n=read();
	for(int i=1;i<=n;i++) a[i]=read();
	for(int i=1;i<=n;i++) p[i]=read();
	n++,a[n]=n;
	m=read();
	for(int i=1;i<=m;i++) b[i]=read();
    for(int i=1;i<=m;i++) pos[b[i]]=i;
    m++,b[m]=n,pos[n]=m;
	for(int i=1;i<N;i++) g[i]=INF;
    for (int i=1,j=1;i<=n;++i)
	{
		if(b[j]<i) j++;
		Q[i]=j;
	}
	ll tmp=0;
    for (int i=1;i<=n;i++)
    {
		int j=pos[a[i]];
		if(j)
		{
			if(g[j-1]>=INF) f[i]=INF;
			else f[i]=g[j-1]+cost[j]+sum;
		}
		if(p[i]>=0) cost[Q[a[i]]]+=p[i];
		else sum+=p[i];
		if(j&&f[i]<INF) g[j]=min(g[j],f[i]-sum);
	}
    if (f[n]<INF) printf("%lld",f[n]);
    else puts("Impossible");
    return 0;
}