点分治入门题
首先可以直接枚举所有两点的lca强行dp
设 $f [ x ] [ 0/1/2 ]$ 表示节点 $x$ 在模3意义下,$x$ 的子树所有节点到 $x$ 的距离为 $0/1/2$ 时的方案数
初始 $f [ x ] [ 0 ] =1$ (本身到自己有一种方案)
转移就枚举所有儿子 $v$ ,设 $x$ 到 $v$ 的距离为 $w$,那么转移显然为:
$ f [ x ] [ (w+0)\%3 ] += f [ v ] [ 0 ] $
$ f [ x ] [ (w+1)\%3 ] += f [ v ] [ 1 ] $
$ f [ x ] [ (w+2)\%3 ] += f [ v ] [ 2 ] $
统计答案也十分显然,对 $w$ 分类讨论一下就好了:
inline void work(int x)//注意函数名不是"dfs",x就是我们枚举的lca{ f[x][0]=1; f[x][1]=f[x][2]=0; for(int i=fir[x];i;i=from[i]) { int &v=to[i],&w=val[i]; if(vis[v]) continue; dfs(v,x);//dfs求出儿子的f if(w==0) ans+=f[x][0]*f[v][0]+f[x][1]*f[v][2]+f[x][2]*f[v][1]; if(w==1) ans+=f[x][0]*f[v][2]+f[x][1]*f[v][1]+f[x][2]*f[v][0]; if(w==2) ans+=f[x][0]*f[v][1]+f[x][1]*f[v][0]+f[x][2]*f[v][2]; //注意先统计ans再转移f f[x][w]+=f[v][0]; f[x][fk(w+1)]+=f[v][1]; f[x][fk(w+2)]+=f[v][2]; }}
但是最坏情况会被卡到 $O(n^2)$
所以上点分治,每次找重心作lca,这样每次子树大小至少减半
枚举lca复杂度$O(n)$,搞dp因为子树大小每次减半所以复杂度约为 $O(log_n)$
总复杂度 $O(nlog_n)$
注意long long
#include#include #include #include #include #include using namespace std;typedef long long ll;inline int read(){ int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f;}const int N=5e5+7,INF=1e9+7;int fir[N],from[N<<1],to[N<<1],val[N<<1],cntt;inline void add(int &a,int &b,int &c){ from[++cntt]=fir[a]; fir[a]=cntt; to[cntt]=b; val[cntt]=c;}inline int fk(int x) { return x>=3 ? x-3 : x; }int n,rt,tot;ll ans,f[N][3];int sz[N],mx[N];bool vis[N];void find_rt(int x,int fa)//找重心{ mx[x]=0; sz[x]=1; for(int i=fir[x];i;i=from[i]) { int &v=to[i]; if(vis[v]||v==fa) continue; find_rt(v,x); sz[x]+=sz[v]; mx[x]=max(mx[x],sz[v]); } mx[x]=max(mx[x],tot-sz[x]); if(mx[x]
其实此题不用点分治
可以直接树形dp,转移同上...代码又短又好写
#include#include #include #include #include #include using namespace std;typedef long long ll;inline int read(){ int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f;}const int N=5e5+7;inline int fk(int x) { return x>=3 ? x-3 : x; }int fir[N],from[N<<1],to[N<<1],val[N<<1],cntt;inline void add(int &a,int &b,int &c){ from[++cntt]=fir[a]; fir[a]=cntt; to[cntt]=b; val[cntt]=c;}ll ans,f[N][3];void dfs(int x,int fa){ f[x][0]=1; for(int i=fir[x];i;i=from[i]) { int &v=to[i],&w=val[i]; if(v==fa) continue; dfs(v,x); if(w==0) ans+=f[x][0]*f[v][0]+f[x][1]*f[v][2]+f[x][2]*f[v][1]; if(w==1) ans+=f[x][0]*f[v][2]+f[x][1]*f[v][1]+f[x][2]*f[v][0]; if(w==2) ans+=f[x][0]*f[v][1]+f[x][1]*f[v][0]+f[x][2]*f[v][2]; f[x][w]+=f[v][0]; f[x][fk(w+1)]+=f[v][1]; f[x][fk(w+2)]+=f[v][2]; }}int n;ll gcd(ll a,ll b) { return b ? gcd(b,a%b) : a; }int main(){ int a,b,c; n=read(); for(int i=1;i