2015年3月17日 星期二

[TIOJ 1696] Problem F 橘子園保衛戰

這題聽吳仲昇講解才會QQ

作法:

以下記每個點的保衛值為 val[ x ] 。
考慮樹分治,對於一個子樹,先選出他的重心當根,把它叫做 x 好了,然後我們的目的是要把根拔掉,分出好幾顆子樹,再分別做這些子樹。所以第一個我們需要把根的答案加上「這顆子樹中和他距離 <= val[ x ] 的點數個數 」,並且因為接下來要把各個子樹拆開處理,假設子樹們為 T1 , ... , Tk ,那麼對於 T1 裡的某個點 P 而言,在 T1 裡面並且和他距離 <= val[ P ] 的個數會在計算子問題 ( T1 ) 的時候被算完,而 T2 ~ Tk 裡面的點對 P 的貢獻在之後就不會被算到了,所以這時候必須把 T2 ~ Tk ( 還有 x  )提供的和 P 距離 <= val[ P ] 的點個數先加到 P 的答案裡面。

所以現在的問題變成要如何對每個子樹計算那個值。設 T1 中的點 P 的深度為 d[ P ] ( 其中 x 的 d 值是 0 ),那麼要算的 P 的答案就是用「以 x 為根的子樹中和 P 的距離 <= val[ P ] - d[ P ] 」的點數量,扣掉「 T1 中和 T1 的根的距離 <= val[ P ] - d[ P ] - 1 」 的點數量,所以我們需要先 DFS 求出以 x 為根的子樹中距離 <= 某個值 k 的點數量有多少( 在我的 code 裡是用 sum 陣列紀錄 ),並且在算下一個子樹 Ti 裡的所求值時候,先算出 Ti 中到 Ti 的根的距離 <= 某個值 k 的點數量有多少 ( 用 sum2 陣列紀錄 ),就可以算出每個點要加多少值了。

但如果在每次做樹分治的時候都對 sum 陣列和 sum2 陣列直接用 memset 歸零 ,那麼時間就會爆掉。但事實上當某個子樹的大小 = size 的時候,到子樹的根的距離頂多是 size ,也就是 sum ( 或 sum2 ) 陣列其實從某一個點之後他的值就不在變了,所以當我們在求 sum 陣列的時候,第一步是先算出「到子樹的根的距離恰為 k 的樹有幾個」,再做他的前綴和,所以只需要把前 size 個都設成0就好了 ( 用 fill 函式 )。但這樣要注意到,之後要查詢 sum 陣列裡的值的時候,如果查詢的 index > size ,那麼這次查詢的答案就會是 sum[ size ] ,因為 sum[ size ] 以後的值都是錯的,我們只有維護 0 ~ size 的值,所以要特判。

code :

#include<bits/stdc++.h>
using namespace std;
const int maxn=100000+10 ;
 
int val[maxn] ;
vector<int> v[maxn] ;
bool vis[maxn] ;
int sum[maxn],sum2[maxn],sz1,sz2 ;
int cnt ;
 
int d[maxn] ;
void dfs0(int x,int &M,int dep)
{
    d[x]=dep ; vis[x]=1 ; cnt++ ;
    if(d[x]>d[M]) M=x ;
    for(auto i : v[x]) if(!vis[i])
        dfs0(i,M,dep+1) ;
    vis[x]=0 ;
}
 
int get_cent(int x,int &sz)
{
    int y=x ; cnt=0 ;
    dfs0(x,y,0) ; sz=cnt ;
    x=y ;
    dfs0(y,x,0) ;
 
    int maxd=d[x] ;
    if(!maxd) return x ;
    for(int i=x;;)
    {
        for(auto j : v[i]) if(!vis[j] && d[j]==d[i]-1)
            { i=j ; break ; }
        if(d[i]==maxd/2) return i ;
    }
}
 
void dfs_dis(int x,int dep,int *sm)
{
    sm[dep]++ ;
    vis[x]=1 ;
    for(auto i : v[x]) if(!vis[i])
        dfs_dis(i,dep+1,sm) ;
    vis[x]=0 ;
}
 
int ans[maxn] ;
 
void dfs_cal(int x,int dep)
{
    vis[x]=1 ;
    for(auto i : v[x]) if(!vis[i])
        dfs_cal(i,dep+1) ;
    vis[x]=0 ;
 
    int val2=val[x]-dep-1 ;
    if(val2>=0) ans[x]+= ( val2<=sz1 ? sum[val2] : sum[sz1] ) ,
                ans[x]-= (val2<=sz2 ? sum2[val2] : sum2[sz2]) ;
}
 
void solve(int y)
{
    int x=get_cent(y,sz1) ;
 
    fill(sum,sum+sz1+1,0) ;
    dfs_dis(x,0,sum) ;
    for(int i=1;i<=sz1;i++) sum[i]+=sum[i-1] ;
    ans[x]+= (val[x]<=sz1 ? sum[val[x]] : sum[sz1]) ;
 
    vis[x]=1 ;
    for(auto i : v[x]) if(!vis[i])
    {
        get_cent(i,sz2) ;
        fill(sum2,sum2+sz2+1,0) ;
        dfs_dis(i,1,sum2) ;
        for(int i=1;i<=sz2;i++) sum2[i]+=sum2[i-1] ;
        dfs_cal(i,0) ;
    }
    for(auto i : v[x]) if(!vis[i])
        solve(i) ;
}
 
main()
{
    int n ; scanf("%d",&n) ;
    for(int i=1;i<=n;i++) scanf("%d",&val[i]) ;
    for(int i=1;i<n;i++)
    {
        int x,y ; scanf("%d%d",&x,&y) ;
        v[x].push_back(y) ;
        v[y].push_back(x) ;
    }
    solve(1) ;
    for(int i=1;i<=n;i++) printf("%d\n",ans[i]) ;
}
 

沒有留言:

張貼留言