// Heavy-Light Decomposition

#include<bits/stdc++.h>
#define ll long long

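// NOTE:
//      Nodes are 1-INDEXED (1..n); node 0 is used internally to mean "no parent".

// Usage:
    // HLD hld;
    // hld.init(n);            // n = number of nodes
    // hld.addEdge(u, v);      // once per tree edge
    // hld.decompose(root);
    // Build the segment tree over array positions 1..n: position i holds tree node inarray_inv[i].
    // hld.getLCA(u, v)        // lowest common ancestor with respect to root
    // hld.getHLPath(u, v) returns the {from, to} array-index ranges that together cover the tree path between u and v.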
 
// Heavy-Light Decomposition of a tree whose nodes are numbered 1..n.
// Every per-node table is a std::vector of size n+1 (index 0 is the "no node" sentinel),
// so the structure owns its memory, can be re-init()'ed without leaking, and copies safely.
// Nothing here recurses, so very deep trees (e.g. long paths) cannot overflow the call stack.
struct HLD // Heavy Light Decomposition
{
    int n = 0;
    std::vector<int> sz;          // sz[u]: number of nodes in the subtree of u
    std::vector<int> next_heavy;  // next_heavy[u]: heavy (largest-subtree) child of u, or -1 if none
    std::vector<int> depth;       // depth[root] == 1; depth[0] == 0 serves as the sentinel
    std::vector<int> pa;          // pa[u]: parent of u; 0 for the root
    std::vector<int> head_heavy;  // head_heavy[u]: topmost node of u's heavy chain
    std::vector<int> tail_heavy;  // tail_heavy[u]: bottommost node of u's heavy chain
    std::vector<int> inarray;     // inarray[u]: array index (1..n) of node u; each heavy chain is a contiguous, top-down range
    std::vector<int> inarray_inv; // inarray_inv[i]: tree node stored at array index i (use this to build/init the segment tree)

    std::vector<std::vector<int>> adj; // adjacency lists, filled by addEdge()

    HLD() {}

    // Allocates (or re-allocates) all tables for nodes 1..n_ and clears every edge.
    void init(int n_) {
        n = n_;
        adj.assign(n + 1, std::vector<int>());
        resetArrays();
    }

    // Puts every per-node table back into its "not yet decomposed" state.
    // Edges in adj are kept.
    void resetArrays() {
        sz.assign(n + 1, 0);
        next_heavy.assign(n + 1, -1);
        depth.assign(n + 1, 0);
        pa.assign(n + 1, 0);
        head_heavy.assign(n + 1, -1);
        tail_heavy.assign(n + 1, -1);
        inarray.assign(n + 1, -1);
        inarray_inv.assign(n + 1, -1);
    }

    // Adds the undirected edge u - v.
    void addEdge(int u, int v) {
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // First pass over the subtree of u, where p is u's parent (0 for the root).
    // Fills pa, depth, sz and next_heavy.
    // The heavy child is the FIRST child in adjacency order that has the largest subtree.
    // Iterative: an explicit preorder, then a reverse sweep so children finish before parents.
    void dfs(int u, int p) {
        pa[u] = p;
        depth[u] = depth[p] + 1;

        std::vector<int> order;
        std::vector<int> todo(1, u);
        while (!todo.empty()) {
            const int x = todo.back();
            todo.pop_back();
            order.push_back(x);
            for (int y : adj[x]) {
                if (y == pa[x]) continue;
                pa[y] = x;
                depth[y] = depth[x] + 1;
                todo.push_back(y);
            }
        }

        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int x = *it;
            sz[x] = 1;
            next_heavy[x] = -1; // reset: a stale value from an earlier root would corrupt dfs2
            int max_size = 0;
            for (int y : adj[x]) {
                if (y == pa[x]) continue;
                sz[x] += sz[y];
                if (sz[y] > max_size) { // strict '>' keeps the first maximal child
                    max_size = sz[y];
                    next_heavy[x] = y;
                }
            }
        }
    }

    // Second pass: assigns array indices starting at cnt (cnt is advanced past the subtree).
    // Also fills head_heavy (with u's chain head = h) and tail_heavy.
    // The visiting order is a preorder that descends into the heavy child first and then the
    // light children in adjacency order, so every heavy chain gets consecutive indices.
    void dfs2(int u, int h, int &cnt) {
        head_heavy[u] = h;

        std::vector<int> order;
        std::vector<int> todo(1, u);
        while (!todo.empty()) {
            const int x = todo.back();
            todo.pop_back();
            inarray[x] = cnt;
            inarray_inv[cnt] = x;
            ++cnt;
            order.push_back(x);

            const int heavy = next_heavy[x];
            // LIFO stack: push light children in reverse adjacency order, heavy child last,
            // so the heavy child's whole subtree is numbered before any light child.
            for (auto it = adj[x].rbegin(); it != adj[x].rend(); ++it) {
                const int y = *it;
                if (y != pa[x] && y != heavy) {
                    head_heavy[y] = y; // a light child starts a new chain
                    todo.push_back(y);
                }
            }
            if (heavy > 0) {
                head_heavy[heavy] = head_heavy[x];
                todo.push_back(heavy);
            }
        }

        // Reverse preorder visits the heavy child before its parent.
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int x = *it;
            tail_heavy[x] = next_heavy[x] > 0 ? tail_heavy[next_heavy[x]] : x;
        }
    }

    // Roots the tree at root and builds the decomposition.
    // Safe to call again with a different root: all tables are reset first.
    void decompose(int root) {
        resetArrays();
        int cnt = 1;
        dfs(root, 0);
        dfs2(root, root, cnt);
    }

    // Lowest common ancestor of u and v with respect to the root given to decompose().
    // While u and v are on different chains, jump the endpoint whose chain head is deeper
    // to the parent of that head. Once they share a chain, the shallower node is the LCA.
    int getLCA(int u, int v) {
        while (head_heavy[u] != head_heavy[v]) {
            if (depth[head_heavy[u]] < depth[head_heavy[v]])
                std::swap(u, v);
            u = pa[head_heavy[u]];
        }
        return depth[u] < depth[v] ? u : v;
    }

    // Appends to paths the array-index ranges covering the vertical path from child up to
    // parent. parent must be an ancestor of child (or child itself).
    // Each pair is {index of the lower end, index of the upper end}, so first >= second.
    // Pairs are appended from child upward. parent is excluded when include_parent is false.
    void travelHL(int child, int parent, std::vector<std::pair<int, int>> &paths, bool include_parent) {
        while (head_heavy[child] != head_heavy[parent]) {
            paths.emplace_back(inarray[child], inarray[head_heavy[child]]);
            child = pa[head_heavy[child]];
        }

        // child and parent now lie on the same heavy chain.
        if (include_parent) {
            paths.emplace_back(inarray[child], inarray[parent]);
        }
        else if (child != parent) {
            // The node right below parent on this chain is its heavy child.
            paths.emplace_back(inarray[child], inarray[next_heavy[parent]]);
        }
    }

    // Splits the tree path between u and v into contiguous array-index ranges
    // (positions in inarray_inv / the segment tree). Requires decompose() to have been called.
    //
    // Default (oriented == false), kept for backward compatibility:
    //   * every pair is {deeper index, shallower index} (first >= second);
    //   * u-side ranges come first, listed from u upward, INCLUDING the LCA;
    //   * then v-side ranges, listed from v upward, EXCLUDING the LCA.
    //   This is enough for order-independent aggregates (sum, min, max, xor, ...).
    // oriented == true: the v-side ranges are reversed and each pair flipped, so walking the
    //   {from, to} pairs in order walks the path exactly from u to v. Use this for
    //   order-dependent queries.
    std::vector<std::pair<int, int>> getHLPath(int u, int v, bool oriented = false) {
        std::vector<std::pair<int, int>> paths;
        const int lca = getLCA(u, v);

        travelHL(u, lca, paths, true);
        if (!oriented) {
            travelHL(v, lca, paths, false);
            return paths;
        }

        std::vector<std::pair<int, int>> down;
        travelHL(v, lca, down, false);
        for (auto it = down.rbegin(); it != down.rend(); ++it)
            paths.emplace_back(it->second, it->first);
        return paths;
    }
};
 

Leave a Reply