2015年7月11日 星期六

[POI 13 Stage 3] Palindromes

作法:

考慮把兩個不同的字串$S_1,S_2$接起來會形成回文,不妨設$|S_1|\leq |S_2|$,那麼不難推得他是回文的充要條件為:$S_1$是$S_2$的前綴,並且$S_2$的長度為$|S_2|-|S_1|$的前綴也是回文。因此我們可以考慮先建一個 trie ,那麼就可以得出:對每個字串$S$來說,有幾個字串恰好為$S$的長度為$i$的前綴。並且再用 manacher 處理一個字串的子字串是否為回文的詢問就可以了。但直接建 trie 會 MLE ,因為一個 node 存 26 個子節點太浪費了,於是我苦苦的把子節點的紀錄方式改成用 treap 才過,詳細就參考 code 吧。

code :



#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int maxn=2000000+10 ;
 
struct manacher
{
    char *s ;
    int n , *Z ;
    int extl(int l,int r)
    {
        int i=1 ;
        for(;l-i+1>=0 && r+i-1<=2*n-2 && s[l-i+1]==s[r+i-1];i++) ;
        return i-1 ;
    }
    void init(char *t)
    {
        n=strlen(t) ;
        s=new char[2*n+2] ;
        Z=new int[2*n+2] ;
        for(int i=0;i<2*n-1;i++) s[i]=(i%2 ? '$' : t[i/2]) ;
        for(int i=0,L=-1,R=-1;s[i];i++)
        {
            int i2=L+R-i ;
            if(R<i) Z[i]=extl(i,i) ;
            else if(Z[i2]!=R-i+1) Z[i]=min(Z[i2],R-i+1) ;
            else Z[i]=extl(2*i-R,R)+(R-i) ;
            if(i+Z[i]-1>=R) L=i-Z[i]+1 , R=i+Z[i]-1 ;
        }
    }
    bool check(int l,int r)
    {
        return Z[r+l]>=r-l+1 ;
    }
}man[maxn];
 
struct node
{
    node *l,*r ;
    int c ; int val,fix ;
    node(int _c,int _val){c=_c ; val=_val ; fix=rand() ;}
};
 
node *merge(node *a,node *b)
{
    if(!a || !b) return a ? a : b ;
    if(a->fix<b->fix)
    {
        node *u=a ;
        u->r=merge(a->r,b) ;
        return u ;
    }
    else
    {
        node *u=b ;
        u->l=merge(a,b->l) ;
        return u ;
    }
}
void split(node *u,node *&a,node *&b,int k)
{
    if(!u){a=b=NULL ; return ;}
    if(u->c <= k)
    {
        a=u ;
        split(u->r,a->r,b,k) ;
    }
    else
    {
        b=u ;
        split(u->l,a,b->l,k) ;
    }
}
 
int ccnt ;
node *root[maxn] ;
int find(node *&u,int c)
{
    if(!u){u=new node(c,++ccnt) ; return ccnt ; }
    if(u->c == c) return u->val ;
    else if(u->c < c) return find(u->l,c) ;
    else return find(u->r,c) ;
}
 
int val[maxn] ;
void insert(char *t)
{
    int now=0 ;
    for(int i=0;t[i];i++)
    {
        int c=t[i]-'a' ;
        now=find(root[now],c) ;
    }
    val[now]++ ;
}
 
char *s[maxn] ;
int len[maxn] ;
main()
{
    srand(time(NULL)) ;
    int n ; scanf("%d",&n) ;
    for(int i=0;i<n;i++)
    {
        int m ; scanf("%d",&m) ;
        s[i]=new char[m+5] ; scanf("%s",s[i]) ;
        man[i].init(s[i]) ;
        insert(s[i]) ;
        len[i]=man[i].n ;
    }
    LL ans=0 ;
    for(int i=0;i<n;i++) for(int j=0,now=0;j<len[i];j++)
    {
        now=find(root[now],s[i][j]-'a') ;
        int mul=(j==len[i]-1 ? 1 : man[i].check(0,len[i]-j-2)) ;
        ans+=2*mul*val[now] ;
    }
    ans-=n ;
    for(int i=1;i<=ccnt;i++) ans-=(LL)val[i]*(val[i]-1)/2 ;
    printf("%lld\n",ans) ;
}
 

沒有留言:

張貼留言