iinattr

虚树入门经典

2022-07-24

虚树入门经典

简介

对于某些树上问题 我们时常需要遍历整颗树
然而当这颗树过大时 遍历整棵树就会很慢
因此我们可以只选取某些对答案有影响的关键节点来遍历
本质上是通过倍增减少遍历的时间开销

例题

P2495[SDOI2011]消耗战
(本题板子)
可以发现 本题是非常水的树形dp
然而 询问有$5*10^5$次
发现本题实际上在不改变父子关系与兄弟时只关心路径最小值
并且关键点相对较少
所以我们可以把每次询问的关键点单独选出 并把关键点两两间的LCA和关键点加入一颗独立于原树之上的虚树中
使用倍增维护LCA及父路径最小值即可

高效维护虚树

显然 直接使用两两间LCA过慢 可以通过栈来维护这颗树
栈中存放了在一条链上的点
依次讲关键点加入其中
当关键点与栈顶的LCA就是栈顶时
说明它们处于一条链上
直接将该关键点压入栈中即可
若关键点与栈顶不在一条链上
说明二者LCA出现分叉
不断取出栈中元素直到可将关键点压入即可
另外 若二者LCA不在栈中
需要在合适的位置将LCA压入栈
最后 当所有关键点都入栈后
取出栈内所有元素并在虚树中连边即可

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include<iostream>
#include<vector>
#include<stack>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdio>
using namespace std;
#define N 500005
int n,m,q[N],fa[N][20],g[N][20],dfn[N],rnd[N],cnt=0,dep[N];
vector<int> mp[N],mp1[N];
long long h[N],d[N];
struct Node
{
int u,dfn;
}keys[N];
bool cmp(Node a,Node b)
{
return a.dfn<b.dfn;
}
struct node
{
int v,w;
}edge[N<<2],vt[N<<2];
int sta[N<<3];
int cnt_=0;
void dfs(int x)
{
dep[x]=dep[fa[x][0]]+1;
dfn[x]=++cnt;
rnd[cnt]=x;
for(int i=1;i<=19;i++)
{
fa[x][i]=fa[fa[x][i-1]][i-1];
g[x][i]=min(g[fa[x][i-1]][i-1],g[x][i-1]);
}
for(int i=0;i<mp[x].size();i++)
{
if(edge[mp[x][i]].v==fa[x][0]) continue;
fa[edge[mp[x][i]].v][0]=x;
g[edge[mp[x][i]].v][0]=edge[mp[x][i]].w;
dfs(edge[mp[x][i]].v);
}
}
long long lca(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
while(dep[u]>dep[v])
{
u=fa[u][(int)log2(dep[u]-dep[v])];
}
if(u==v) return u;
for(int i=log2(dep[u]);i>=0;i--)
if(fa[u][i]!=fa[v][i])
u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
long long query(int u, int v) {
int ans = 0x3f3f3f3f;
while (dep[u] > dep[v]) {
ans = min(ans, g[u][(int)log2(dep[u] - dep[v])]);
u = fa[u][(int)log2(dep[u] - dep[v])];
}
return ans;
}
void dp(int u)
{
for(int i=0;i<mp1[u].size();i++)
{
int v=vt[mp1[u][i]].v;
int w=vt[mp1[u][i]].w;
dp(v);
if(h[v])
d[u]+=w;
else
d[u]+=min((long long)w,d[v]);
h[v]=0;d[v]=0;
}
mp1[u].clear();
}
int main()
{
cin>>n;
memset(g,0x3f,sizeof(g));
for(int i=1;i<n;i++)
{
int u,v,w;
cin>>u>>v>>w;
mp[u].push_back(i*2-1);
edge[i*2-1].v=v;edge[i*2-1].w=w;
mp[v].push_back(i*2);
edge[i*2].v=u;edge[i*2].w=w;
}
dfs(1);
scanf("%d",&m);
while(m--)
{
int k;
scanf("%d",&k);
for(int i=1;i<=k;i++)
{
cin>>keys[i].u;
h[keys[i].u]=1;
keys[i].dfn=dfn[keys[i].u];
}
stack<int> s;
sort(keys+1,keys+k+1,cmp);
s.push(1);
for(int i=1;i<=k;i++)
{
int u=keys[i].u;
int la=lca(u,s.top());
while(s.top()!=la)
{
int tmp=s.top();s.pop();
if(dfn[s.top()]<dfn[la])
s.push(la);
mp1[s.top()].push_back(++cnt_);
vt[cnt_].v=tmp;
vt[cnt_].w=query(tmp,s.top());
}
s.push(u);
}
while(s.top()!=1)
{
int tmp=s.top();
s.pop();
mp1[s.top()].push_back(++cnt_);
vt[cnt_].v=tmp;
vt[cnt_].w=query(tmp,s.top());
}
dp(1);
cout<<d[1]<<endl;
d[1]=0;
cnt_=0;
}
return 0;
}