// SQRT Decomposition on Tree

#include<bits/stdc++.h>
#define ll long long

// Note:
//      Vertices are 1-INDEXED
// Usage:
//      Sqrt_Decomposition_Tree sqd;
//      sqd.init(n);
//      sqd.addEdge(...);
//      sqd.sqrt_decompose_tree(root); => will fill parent, iepa, depth, depth_index, sqnode, sqcnt
//      sqd.precompute();        => will fill color_data and dist_to_root
//      sqd.prepareLCA();        => will fill ancestor
//      sqd.compute(u, c, x) will return your needed result
//      sqd.getLCA(u, v) return LCA
// Square-root decomposition of a rooted, edge-labelled tree.
//
// Every edge carries two integers: `c` (used as the key of the per-colour
// tables) and `d` (accumulated into dist_to_root).  A subset of vertices is
// marked as "key" vertices (sqnode[u] != 0) so that, walking upwards from any
// vertex, one reaches either a key vertex or the root within `sq` parent steps.
// Each key vertex stores, per colour, the number and the total `d` of the
// edges between itself and the root; a query therefore only walks the short
// unmarked prefix by hand and reads the rest from the table.
//
// Vertices are 1-indexed; index 0 is a sentinel "no vertex"
// (parent[root] == 0, depth[0] == -1, ancestor[0][*] == 0).
//
// All arrays are std::vector members, so the object releases its own memory
// and can be copied or re-initialised with init().  They are indexed exactly
// like plain arrays (e.g. parent[u], ancestor[u][j]).
struct Sqrt_Decomposition_Tree
{

    struct Edge {
        int v;      // vertex this directed copy of the edge leads to
        int c, d;   // c: key into color_data, d: amount added to dist_to_root
    };

    Sqrt_Decomposition_Tree() {}

    int n = 0;                                   // number of vertices
    int sq = 1;                                  // floor(sqrt(n)), at least 1
    int root = 1;                                // root chosen in sqrt_decompose_tree
    std::vector<Edge> edge;                      // slot k: u->v copy, slot k+n: v->u copy
    int edgecnt = 0;                             // number of addEdge calls so far
    std::vector<int> parent;                     // parent[u], 0 for the root
    std::vector<int> iepa;                       // edge-index to parent
    std::vector<int> depth;                      // depth[root] == 0, depth[0] == -1
    std::vector<std::pair<int, int> > depth_index; // (depth, vertex), sorted ascending from slot 1
    std::vector<std::vector<int> > eadj;         // eadj[u]: indices into `edge` leaving u
    std::vector<int> sqnode;                     // 0, or the 1-based id of a key vertex
    int sqcnt = 0;                               // number of key vertices

    // TODO: declare some variables to store when precompute
    // color_data[id][c] = (count, sum of d) over the edges with that c
    // between key vertex `id` and the root.
    std::vector<std::map<int, std::pair<int, int> > > color_data;
    std::vector<int> dist_to_root;               // sum of d from the root down to each vertex
    // ====

    int lgn = 0;                                 // highest power-of-two jump stored in `ancestor`
    std::vector<std::vector<int> > ancestor;     // ancestor[u][j]: 2^j-th ancestor of u, or 0

    // True when `root` names an existing vertex; false for an empty tree or
    // before init()/sqrt_decompose_tree() were called with sensible values.
    bool rootIsValid() const {
        return root >= 1 && root <= n;
    }

    // Allocates (or re-allocates) all per-vertex storage for n_ vertices and
    // forgets any edges and results from a previous run.
    void init(int n_) {
        n = n_;

        // Integer floor(sqrt(n)); clamped to 1 so the block size is never 0.
        sq = 1;
        while (static_cast<long long>(sq + 1) * (sq + 1) <= n)
            ++sq;

        const std::size_t slots = static_cast<std::size_t>(n > 0 ? n : 0) + 1;

        edge.assign(2 * slots, Edge{0, 0, 0});
        parent.assign(slots, 0);
        iepa.assign(slots, 0);
        depth.assign(slots, -1);            // depth[0] must stay -1 for getLCA
        depth_index.assign(slots, std::pair<int, int>(0, 0));
        eadj.assign(slots, std::vector<int>());
        sqnode.assign(slots, 0);

        color_data.clear();
        dist_to_root.clear();
        ancestor.clear();

        edgecnt = 0;
        sqcnt = 0;
        lgn = 0;
    }

    // Adds the undirected edge u-v with labels (c, d).  The two directed
    // copies live at indices edgecnt and edgecnt + n, which is what dfs()
    // relies on to find the opposite copy; a tree has at most n - 1 edges.
    void addEdge(int u, int v, int c, int d) {
        edgecnt++;
        edge[edgecnt].v = v;
        edge[edgecnt+n].v = u;
        edge[edgecnt].c = edge[edgecnt+n].c = c;
        edge[edgecnt].d = edge[edgecnt+n].d = d;
        eadj[u].push_back(edgecnt);
        eadj[v].push_back(edgecnt+n);
    }

    // Fills parent, depth and iepa for the subtree hanging below u, where pa
    // is u's parent (0 when u is the root).  Uses an explicit stack instead of
    // recursion so that very deep trees cannot exhaust the call stack.
    void dfs(int u, int pa) { // find parent, depth and iepa
        parent[u] = pa;
        depth[u] = (pa == 0) ? 0 : depth[pa] + 1;

        std::vector<int> pending(1, u);
        while (!pending.empty()) {
            const int cur = pending.back();
            pending.pop_back();

            for (int i : eadj[cur]) {
                const int v = edge[i].v;
                if (v == parent[cur])
                    continue;
                parent[v] = cur;
                depth[v] = depth[cur] + 1;
                iepa[v] = i < n ? i + n : i - n;   // the copy pointing back up
                pending.push_back(v);
            }
        }
    }

    // Walks up to sq parent steps from u.  Stops without marking anything if a
    // key vertex or the root is met on the way; otherwise the vertex reached
    // after sq steps becomes a key vertex.
    void markCoreNode(int u) {

        for (int step = 0; step < sq; ++step) {
            if (sqnode[u]) return;
            if (parent[u] <= 0) return;
            u = parent[u];
        }

        // The vertex reached by the last step may already be a key vertex
        // (another descendant got here first).  Re-numbering it would inflate
        // sqcnt and leave unused slots in color_data.
        if (sqnode[u]) return;

        sqnode[u] = ++sqcnt;

    }

    // Roots the tree at root_ (vertex 1 by default), fills parent / depth /
    // iepa / depth_index and selects the key vertices.
    // Returns the number of key vertices, or 0 if root_ is not a vertex.
    int sqrt_decompose_tree(int root_ = 1) {

        root = root_;
        sqcnt = 0;

        if (!rootIsValid())
            return sqcnt;

        // Start from a clean slate so the method can be called again.
        std::fill(sqnode.begin(), sqnode.end(), 0);

        parent[0] = 0;
        iepa[0] = 0;
        dfs(root, 0);   // must start at the chosen root, not at vertex 1

        for (int i = 1; i <= n; ++i) {
            depth_index[i].first = depth[i];
            depth_index[i].second = i;
        }

        // Deepest vertices are processed first, so key vertices are placed
        // from the leaves upwards.
        std::sort(depth_index.begin() + 1, depth_index.begin() + n + 1);
        for (int i = n; i >= 1; --i)
            markCoreNode(depth_index[i].second);

        return sqcnt;
    }

    // TODO: implement these methods
    // Builds the colour table of key vertex u: counts every edge on the way
    // up until the next key vertex (or the root) and then adds that key
    // vertex's table, which precompute() has already built because it visits
    // key vertices in increasing depth.
    void idfs(int u) {

        std::map<int, std::pair<int, int> > &mine = color_data[sqnode[u]];
        int v = u;
        while (parent[v] > 0) {

            // operator[] value-initialises a missing entry to (0, 0).
            const Edge &up = edge[iepa[v]];
            std::pair<int, int> &slot = mine[up.c];
            slot.first += 1;
            slot.second += up.d;

            v = parent[v];

            if (sqnode[v]) {

                for (const auto &entry : color_data[sqnode[v]]) {
                    std::pair<int, int> &acc = mine[entry.first];
                    acc.first += entry.second.first;
                    acc.second += entry.second.second;
                }

                break;
            }
        }

    }

    // Fills dist_to_root for u (given as d) and for everything below u, using
    // the parent array written by dfs().  Iterative for the same reason as dfs().
    void distdfs(int u, int d) {
        dist_to_root[u] = d;

        std::vector<int> pending(1, u);
        while (!pending.empty()) {
            const int cur = pending.back();
            pending.pop_back();

            for (int i : eadj[cur]) {
                const int v = edge[i].v;
                if (v == parent[cur])
                    continue;
                dist_to_root[v] = dist_to_root[cur] + edge[i].d;
                pending.push_back(v);
            }
        }
    }
    // ====

    // Builds color_data for every key vertex and dist_to_root for every
    // vertex.  Call after sqrt_decompose_tree().
    void precompute() {
        // TODO: init variables to store when precompute
        // WARNING: use map instead of array to store values (to prevent Memory Limit Exceeded)
        color_data.assign(static_cast<std::size_t>(sqcnt) + 1,
                          std::map<int, std::pair<int, int> >());

        dist_to_root.assign(static_cast<std::size_t>(n > 0 ? n : 0) + 1, 0);
        // ====

        if (!rootIsValid())
            return;

        // TODO: computation depending of specific problem
        // depth_index is sorted by depth, so a key vertex is always handled
        // after every key vertex above it.
        for (int i = 1; i <= n; ++i) {
            const int u = depth_index[i].second;
            if (sqnode[u])
                idfs(u);
        }

        distdfs(root, 0);
        // ====
    }

    // TODO: implement this method
    // Returns dist_to_root[u] + cnt * x - sum, where cnt and sum are the
    // number and the total d of the edges with colour c on the path from u to
    // the root -- i.e. the distance from u to the root if every such edge had
    // d == x.  Walks up by hand until a key vertex is met, then reads the rest
    // from its table.
    int compute(int u, int c, int x) const {
        const int curr_dist = dist_to_root[u];
        int cnt = 0;
        int sum = 0;

        while (u != root && u > 0) {

            if (sqnode[u]) {
                const std::map<int, std::pair<int, int> > &table = color_data[sqnode[u]];
                const auto hit = table.find(c);
                if (hit != table.end()) {
                    cnt += hit->second.first;
                    sum += hit->second.second;
                }
                break;
            }

            const Edge &up = edge[iepa[u]];
            if (up.c == c) {
                cnt++;
                sum += up.d;
            }

            u = parent[u];
        }

        // Widen the intermediate so cnt * x cannot overflow on its own.
        const long long res = curr_dist + (static_cast<long long>(cnt) * x - sum);
        return static_cast<int>(res);

    }
    // ====

    // Fills the binary-lifting rows of u (whose parent is pa) and of every
    // vertex below it.  A vertex is only processed after its parent, so the
    // rows it reads are already complete.  Iterative, like dfs().
    void lcadfs(int u, int pa) {
        std::vector<std::pair<int, int> > pending(1, std::pair<int, int>(u, pa));
        while (!pending.empty()) {
            const int cur = pending.back().first;
            const int up = pending.back().second;
            pending.pop_back();

            ancestor[cur][0] = up;
            for (int j = 1; j <= lgn; ++j)
                ancestor[cur][j] = ancestor[ancestor[cur][j-1]][j-1];

            for (int i : eadj[cur])
                if (edge[i].v != up)
                    pending.push_back(std::pair<int, int>(edge[i].v, cur));
        }
    }

    // Builds the `ancestor` table for getLCA().  Call after
    // sqrt_decompose_tree(), which fixes the root and the depths.
    void prepareLCA() {
        // lgn = floor(log2(n)) + 1, computed with integers only.
        lgn = 1;
        while ((n >> lgn) != 0)
            ++lgn;

        // Row 0 stays all-zero: jumping past the root lands on the sentinel.
        ancestor.assign(static_cast<std::size_t>(n > 0 ? n : 0) + 1,
                        std::vector<int>(static_cast<std::size_t>(lgn) + 1, 0));

        if (!rootIsValid())
            return;

        lcadfs(root, 0);

    }

    // Returns the lowest common ancestor of u and v with respect to `root`.
    // Requires sqrt_decompose_tree() and prepareLCA() to have been called.
    int getLCA(int u, int v) const {

        if (depth[u] < depth[v])
            std::swap(u, v);

        // balance depth of u and v (depth[0] == -1 rejects jumps past the root)
        for (int j = lgn; j >= 0; --j)
            if (depth[ancestor[u][j]] >= depth[v])
                u = ancestor[u][j];

        // if 1 of them is ancestor of the other, return result
        if (u == v)
            return u;

        // find the highest non-common ancestors of u and v
        for (int j = lgn; j >= 0; --j)
            if (ancestor[u][j] != ancestor[v][j]) {
                u = ancestor[u][j];
                v = ancestor[v][j];
            }
        // return
        return ancestor[u][0];

    }

};

Leave a Reply