// Centroid Decomposition

#include<bits/stdc++.h>
#define ll long long

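// NOTE:
//      Vertices must be numbered 1..n; vertex 0 is used internally as a sentinel
//      (cenPa[x] == 0 means "x has not been placed in the centroid tree yet").
//  Usage:
//      CentroidDecomposition cent;
//      cent.init(n);                 // n = number of vertices
//      cent.addEdge(u, v);           // once per tree edge
//      cent.getCentroid(root) and/or cent.buildCentroidTree();
//      cent.prepareLCA() if needed, then cent.getLCA(u, v)  // LCA in the original tree rooted at vertex 1
//      cent.cenPa[u] => parent of u in the centroid tree (cenPa[cenRoot] == cenRoot).
//      cent.cenAdj[u] stores all neighbours of u in the centroid tree (including its parent).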
struct CentroidDecomposition
{
    // One direction of an undirected edge: `id` is its index in `edges`,
    // `v` is the endpoint it points to.
    struct Edge {
        int id;
        int v;
        Edge(int id_, int v_) : id(id_), v(v_) {}
    };

    int n = 0;
    int root = 0;                            // not read by any member function; kept for compatibility
    std::vector<int> sz;                     // subtree sizes from the last computeSubtreeSize() call
    std::vector<std::set<int>> eadj;         // eadj[u] = ids (indices into `edges`) of half-edges leaving u
    std::vector<Edge> edges;                 // addEdge(u, v) appends u->v then v->u
    std::vector<int> inuse;                  // inuse[u] == 1 iff u is an endpoint of some added edge

    int cenRoot = 0;
    std::vector<int> cenPa;                  // parent in the centroid tree; cenPa[cenRoot] == cenRoot,
                                             // 0 means "not yet placed in the centroid tree"
    std::vector<std::vector<int>> cenAdj;    // centroid-tree neighbours of u (includes cenPa[u] if u != cenRoot)

    int lgn = 0;
    std::vector<int> depth;                  // depth in the original tree; depth[0] = -1 (sentinel above the root)
    std::vector<std::vector<int>> ancestor;  // ancestor[u][j] = 2^j-th ancestor of u, 0 once past the root

    // Resets the structure for a tree on vertices 1..n_ (vertex 0 is a sentinel).
    // May be called again to reuse the object; old storage is released by the vectors.
    void init(int n_) {
        n = n_;
        sz.assign(n + 1, 0);
        inuse.assign(n + 1, 0);
        cenPa.assign(n + 1, 0);
        eadj.assign(n + 1, std::set<int>());
        edges.clear();

        // Drop results derived from a previous tree so they cannot be mistaken for current ones.
        cenRoot = 0;
        cenAdj.clear();
        lgn = 0;
        depth.clear();
        ancestor.clear();
    }

    // Adds the undirected edge (u, v) as two half-edges.
    void addEdge(int u, int v) {
        eadj[u].insert(static_cast<int>(edges.size()));
        edges.push_back(Edge(static_cast<int>(edges.size()), v));

        eadj[v].insert(static_cast<int>(edges.size()));
        edges.push_back(Edge(static_cast<int>(edges.size()), u));

        inuse[u] = inuse[v] = 1;
    }

    // Returns a centroid of the component containing `root`, ignoring vertices already
    // placed in the centroid tree (cenPa[x] != 0).
    // With root < 0, the lowest-numbered vertex that has an edge is used. If no edge was
    // added (single-vertex tree), vertex 1 is used instead (0 when n == 0).
    // NOTE: after buildCentroidTree() every vertex of the tree is marked, so calling this
    // afterwards yields `root` itself.
    int getCentroid(int root = -1) {
        if (root < 0) {
            root = (n >= 1) ? 1 : 0;  // fallback when there are no edges at all
            for (int i = 0; i <= n; ++i)
                if (inuse[i]) {
                    root = i;
                    break;
                }
        }

        computeSubtreeSize(root, -1);
        return systemGetCentroid(root, -1, sz[root]);
    }

    // Builds the centroid tree of the component containing `root` (see getCentroid for
    // root < 0), filling cenRoot, cenPa and cenAdj. Components are processed in BFS order.
    void buildCentroidTree(int root = -1) {
        std::fill(cenPa.begin(), cenPa.end(), 0);  // cenPa[x] != 0 <=> x already in the centroid tree
        cenAdj.assign(n + 1, std::vector<int>());

        std::queue<int> que;

        cenRoot = getCentroid(root);
        cenPa[cenRoot] = cenRoot;
        que.push(cenRoot);

        while (!que.empty()) {
            int u = que.front();
            que.pop();

            // With u removed, each still-unplaced neighbour lies in its own component;
            // that component's centroid becomes a child of u.
            for (int i : eadj[u]) {
                int v = edges[i].v;
                if (!cenPa[v]) {
                    int cenV = getCentroid(v);
                    cenPa[cenV] = u;
                    que.push(cenV);
                    cenAdj[u].push_back(cenV);
                    cenAdj[cenV].push_back(u);
                }
            }
        }
    }

    // Fills sz[] for the subtree of u (parent pa), skipping vertices already in the
    // centroid tree. Recursive: very deep trees (long paths) may need a larger stack.
    void computeSubtreeSize(int u, int pa) {
        sz[u] = 1;

        for (int i : eadj[u]) {
            int v = edges[i].v;
            if (v != pa && !cenPa[v]) {
                computeSubtreeSize(v, u);
                sz[u] += sz[v];
            }
        }
    }

    // Walks from u toward any child whose subtree holds more than half of full_size
    // vertices; the vertex where no such child exists is a centroid. Needs sz[] from
    // computeSubtreeSize on the same component.
    int systemGetCentroid(int u, int pa, int full_size) {
        for (int i : eadj[u]) {
            int v = edges[i].v;
            if (v != pa && !cenPa[v] && sz[v] > full_size / 2)
                return systemGetCentroid(v, u, full_size);
        }
        return u;
    }

    // lca-part
    // ====

    // Records depth and binary-lifting ancestors for the subtree of u (parent pa)
    // in the original tree. Recursive, like computeSubtreeSize.
    void lcadfs(int u, int pa, int d) {
        depth[u] = d;

        ancestor[u][0] = pa;
        for (int j = 1; j <= lgn; ++j)
            ancestor[u][j] = ancestor[ancestor[u][j - 1]][j - 1];

        for (int i : eadj[u])
            if (edges[i].v != pa)
                lcadfs(edges[i].v, u, d + 1);
    }

    // Prepares getLCA() on the original tree rooted at lcaRoot (vertex 1 by default).
    void prepareLCA(int lcaRoot = 1) {
        // Smallest lgn with 2^lgn > n, so that 2^lgn exceeds any possible depth.
        lgn = 1;
        while ((1LL << lgn) <= n)
            ++lgn;

        ancestor.assign(n + 1, std::vector<int>(lgn + 1, 0));  // sentinel 0 is its own ancestor
        depth.assign(n + 1, 0);
        // The sentinel must sit strictly above the root; otherwise a lift that overshoots
        // the root (landing on 0) passes the depth test in getLCA and returns 0.
        depth[0] = -1;

        lcadfs(lcaRoot, 0, 0);
    }

    // Lowest common ancestor of u and v in the tree prepared by prepareLCA().
    // Assumes both vertices are reachable from the LCA root.
    int getLCA(int u, int v) {
        if (depth[u] < depth[v])
            std::swap(u, v);

        // Lift u to the depth of v.
        for (int j = lgn; j >= 0; --j)
            if (depth[ancestor[u][j]] >= depth[v])
                u = ancestor[u][j];

        // One was an ancestor of the other.
        if (u == v)
            return u;

        // Lift both to the highest non-common ancestors.
        for (int j = lgn; j >= 0; --j)
            if (ancestor[u][j] != ancestor[v][j]) {
                u = ancestor[u][j];
                v = ancestor[v][j];
            }

        return ancestor[u][0];
    }

    // ==== end lca-part

    // Prints the centroid root, then every centroid-tree edge as "parent child".
    // Prints only the header line if buildCentroidTree() has not run.
    void print() {
        std::cout << "root centroid: " << cenRoot << '\n';
        for (int i = 0; i < static_cast<int>(cenAdj.size()); ++i)
            for (int v : cenAdj[i])
                if (v != cenPa[i])
                    std::cout << i << " " << v << '\n';
    }
};

Leave a Reply