# SQRT Decomposition on Tree

```cpp
#include<bits/stdc++.h>
#define ll long long

// Square-root decomposition on a rooted tree with per-colour edge statistics.
//
// Note:
//      Vertices are 1-INDEXED; vertex 0 is a sentinel meaning "no parent".
//      Each undirected edge is stored twice: the copy added k-th lives at
//      edge[k] (1 <= k < n) and its reverse at edge[k + n].
// Usage:
//      Sqrt_Decomposition_Tree sqd;
//      sqd.init(n);
//      sqd.addEdge(u, v, c, d);      => once per tree edge (n-1 calls)
//      sqd.sqrt_decompose_tree(r);   => roots at r; fills parent, depth, iepa,
//                                       depth_index, sqnode, sqcnt
//      sqd.precompute();             => fills color_data, dist_to_root
//      sqd.prepareLCA();             => fills ancestor
//      sqd.compute(u, c, x) will return your needed result
//      sqd.getLCA(u, v) return LCA
struct Sqrt_Decomposition_Tree
{
    struct Edge {
        int v;      // endpoint this copy of the edge points to
        int c, d;   // colour and length
    };

    Sqrt_Decomposition_Tree() {}

    int n = 0;
    int sq = 0;                                     // floor(sqrt(n))
    int root = 0;
    std::vector<Edge> edge;                         // layout described above
    int edgecnt = 0;
    std::vector<std::vector<int> > eadj;            // eadj[u] = edge indices leaving u
    std::vector<int> parent;
    std::vector<int> iepa;                          // edge-index to parent
    std::vector<int> depth;                         // depth[0] stays -1 (sentinel)
    std::vector<std::pair<int, int> > depth_index;  // (depth, vertex), sorted from index 1
    std::vector<int> sqnode;                        // 0 = ordinary, k > 0 = core-node index
    int sqcnt = 0;

    // Problem-specific data (adapt per problem).
    // color_data[k][c] = (count, total length) of colour-c edges on the path
    // from core node k up to the root.
    std::vector<std::map<int, std::pair<int, int> > > color_data;
    std::vector<int> dist_to_root;

    int lgn = 0;
    std::vector<std::vector<int> > ancestor;        // ancestor[u][j] = 2^j-th ancestor, 0 past the root

    // Allocates storage for vertices 1..n_ and clears any previous state,
    // so the same object can be reused for several trees.
    void init(int n_) {
        n = n_;
        sq = static_cast<int>(std::sqrt(static_cast<double>(n)));

        edge.assign(2 * n + 2, Edge());
        eadj.assign(n + 1, std::vector<int>());
        parent.assign(n + 1, 0);
        iepa.assign(n + 1, 0);
        depth.assign(n + 1, -1);
        depth_index.assign(n + 1, std::make_pair(0, 0));
        sqnode.assign(n + 1, 0);

        edgecnt = 0;
    }

    // Adds the undirected edge u-v with colour c and length d.
    // Expects at most n-1 calls, so the reverse copy at +n never collides.
    void addEdge(int u, int v, int c, int d) {
        ++edgecnt;
        const int fwd = edgecnt;
        const int rev = edgecnt + n;
        edge[fwd].v = v;
        edge[rev].v = u;
        edge[fwd].c = edge[rev].c = c;
        edge[fwd].d = edge[rev].d = d;
        eadj[u].push_back(fwd);   // u -> v
        eadj[v].push_back(rev);   // v -> u
    }

    // Fills parent, depth and iepa for the subtree of u, whose parent is pa
    // (pa == 0 means u is the root). Recursion depth equals the tree height.
    void dfs(int u, int pa) {
        parent[u] = pa;
        depth[u] = (pa == 0) ? 0 : depth[pa] + 1;

        for (int i : eadj[u]) {
            const int v = edge[i].v;
            if (v != pa) {
                iepa[v] = i < n ? i + n : i - n;   // the opposite copy of edge i
                dfs(v, u);
            }
        }
    }

    // Walks up to sq steps from u. If neither a core node nor the root is met,
    // marks the vertex sq levels above u as a new core node.
    void markCoreNode(int u) {
        for (int i = 0; i < sq; ++i) {
            if (sqnode[u]) return;
            if (parent[u] <= 0) return;
            u = parent[u];
        }
        // The vertex reached by the last step was never checked inside the
        // loop; without this guard it would be re-marked with a new index.
        if (sqnode[u] == 0)
            sqnode[u] = ++sqcnt;
    }

    // Roots the tree at root_, then marks core nodes from the deepest vertex
    // upwards, so that from every vertex at most sq parent steps reach a core
    // node or the root. Returns the number of core nodes.
    int sqrt_decompose_tree(int root_) {
        root = root_;
        sqcnt = 0;
        std::fill(sqnode.begin(), sqnode.end(), 0);

        parent[0] = 0;
        iepa[0] = 0;
        dfs(root, 0);

        for (int i = 1; i <= n; ++i)
            depth_index[i] = std::make_pair(depth[i], i);

        std::sort(depth_index.begin() + 1, depth_index.end());
        for (int i = n; i >= 1; --i)
            markCoreNode(depth_index[i].second);

        return sqcnt;
    }

    // Builds color_data for core node u: tallies edges from u up to its nearest
    // core ancestor, then adds that ancestor's data. Core nodes must therefore
    // be processed in non-decreasing depth order (see precompute).
    void idfs(int u) {
        std::map<int, std::pair<int, int> >& mine = color_data[sqnode[u]];
        int v = u;
        while (parent[v] > 0) {
            const Edge& e = edge[iepa[v]];
            std::pair<int, int>& stat = mine[e.c];   // value-initialised to (0, 0)
            stat.first++;
            stat.second += e.d;

            v = parent[v];

            if (sqnode[v]) {
                const std::map<int, std::pair<int, int> >& above = color_data[sqnode[v]];
                for (const auto& kv : above) {
                    std::pair<int, int>& acc = mine[kv.first];
                    acc.first += kv.second.first;
                    acc.second += kv.second.second;
                }
                break;
            }
        }
    }

    // Sets dist_to_root for the subtree of u, where d is the root-to-u
    // distance. Requires parent[] from sqrt_decompose_tree.
    void distdfs(int u, int d) {
        dist_to_root[u] = d;
        for (int i : eadj[u])
            if (edge[i].v != parent[u])
                distdfs(edge[i].v, d + edge[i].d);
    }

    // Problem-specific precomputation; call after sqrt_decompose_tree.
    void precompute() {
        // One map per core node rather than a dense array, to bound memory.
        color_data.assign(sqcnt + 1, std::map<int, std::pair<int, int> >());
        dist_to_root.assign(n + 1, 0);

        // Shallow core nodes first: idfs(u) reads its core ancestor's data.
        for (int i = 1; i <= n; ++i) {
            const int u = depth_index[i].second;
            if (sqnode[u])
                idfs(u);
        }

        distdfs(root, 0);
    }

    // Problem-specific query: root-to-u distance after every colour-c edge is
    // given length x. Walks at most sq parent links before reaching a core
    // node (whose precomputed totals finish the job) or the root.
    // NOTE(review): all arithmetic is int; inputs must keep the result in range.
    int compute(int u, int c, int x) const {
        const int curr_dist = dist_to_root[u];
        int cnt = 0;   // colour-c edges on the root..u path
        int sum = 0;   // their original total length

        for (int w = u; w != root && w > 0; w = parent[w]) {
            if (sqnode[w]) {
                const std::map<int, std::pair<int, int> >& data = color_data[sqnode[w]];
                const auto it = data.find(c);
                if (it != data.end()) {
                    cnt += it->second.first;
                    sum += it->second.second;
                }
                break;
            }
            const Edge& e = edge[iepa[w]];
            if (e.c == c) {
                cnt++;
                sum += e.d;
            }
        }

        return curr_dist + (cnt * x - sum);
    }

    // Fills the binary-lifting rows for the subtree of u (parent pa). Pre-order,
    // so every ancestor's row is complete before a descendant reads it.
    void lcadfs(int u, int pa) {
        ancestor[u][0] = pa;
        for (int j = 1; j <= lgn; ++j)
            ancestor[u][j] = ancestor[ancestor[u][j - 1]][j - 1];

        for (int i : eadj[u])
            if (edge[i].v != pa)
                lcadfs(edge[i].v, u);
    }

    // Builds the ancestor table; call after sqrt_decompose_tree (needs root).
    void prepareLCA() {
        // lgn = floor(log2(n)) + 1, computed in integers (0 when n == 0).
        lgn = 0;
        while ((1 << lgn) <= n)
            ++lgn;

        // Row 0 is the sentinel and stays all zeros.
        ancestor.assign(n + 1, std::vector<int>(lgn + 1, 0));
        lcadfs(root, 0);
    }

    // Lowest common ancestor of u and v via binary lifting.
    int getLCA(int u, int v) const {
        if (depth[u] < depth[v])
            std::swap(u, v);

        // balance depth of u and v (depth[0] == -1, so the sentinel is never taken)
        for (int j = lgn; j >= 0; --j)
            if (depth[ancestor[u][j]] >= depth[v])
                u = ancestor[u][j];

        // if 1 of them is ancestor of the other, return result
        if (u == v)
            return u;

        // find the highest non-common ancestors of u and v
        for (int j = lgn; j >= 0; --j)
            if (ancestor[u][j] != ancestor[v][j]) {
                u = ancestor[u][j];
                v = ancestor[v][j];
            }
        return ancestor[u][0];
    }

};
```