#include<bits/stdc++.h>
#define ll long long
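// Note:
// Vertices are 1-indexed; index 0 is reserved as the "no vertex" sentinel (parent of the root).
// Usage:
// Sqrt_Decomposition_Tree sqd;
// sqd.init(n);                   // allocate storage for vertices 1..n
// sqd.addEdge(u, v, c, d);       // once per tree edge: endpoints u, v, colour c, length d
// sqd.sqrt_decompose_tree(root); // fills parent, iepa, depth, depth_index, sqnode, sqcnt
// sqd.precompute();              // fills color_data and dist_to_root
// sqd.prepareLCA();              // fills ancestor (binary-lifting table)
// sqd.compute(u, c, x);          // distance from u to the root if every colour-c edge had length x
// sqd.getLCA(u, v);              // lowest common ancestor of u and v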
// Square-root decomposition of a tree whose edges carry a colour and a length.
//
// Vertices are 1-indexed; 0 is the "no vertex" sentinel (parent of the root,
// jump target beyond the root). sqrt_decompose_tree() marks "core" nodes
// (sqnode[v] != 0) so that walking up at most sq steps from any vertex meets a
// core node or the root. precompute() stores, for every core node, a per-colour
// (edge count, length sum) over its whole path to the root, so compute() only
// has to walk up to the nearest core node before using the stored totals.
//
// Storage is allocated with new[] and never freed: calling init(), precompute()
// or prepareLCA() again allocates fresh arrays and leaks the previous ones.
// All sums are plain int, so callers must keep path lengths (and count * x in
// compute()) within int range.
struct Sqrt_Decomposition_Tree
{
    // One directed half of an undirected edge: head vertex v, colour c, length d.
    struct Edge {
        int v;
        int c, d;
    };

    Sqrt_Decomposition_Tree() {}

    int n;                             // number of vertices
    int sq;                            // maximum walk length between core nodes: (int)sqrt(n)
    int root;                          // root chosen by sqrt_decompose_tree()
    Edge *edge;                        // edge[k] (k = 1..n-1) and its reverse edge[k + n]
    int edgecnt;                       // number of addEdge() calls so far
    int *parent;                       // parent[v]; 0 for the root
    int *iepa;                         // iepa[v]: index in edge[] of the edge from v to parent[v]
    int *depth;                        // depth[v] (root = 0); depth[0] stays -1 as getLCA's sentinel
    std::pair<int, int> *depth_index;  // (depth, vertex) pairs, sorted by depth, 1-indexed
    std::vector<int> *eadj;            // eadj[v]: indices into edge[] of the edges leaving v
    int *sqnode;                       // sqnode[v]: 1-based core-node id, or 0 if v is not core
    int sqcnt;                         // number of core nodes
    // Per core-node id: colour -> (edge count, length sum) on the path to the root.
    // A map per core node (not a dense [core][colour] table) keeps memory
    // proportional to the colours actually present.
    std::map<int, std::pair<int, int> > *color_data;
    int *dist_to_root;                 // dist_to_root[v]: total edge length from v to the root
    int lgn;                           // highest binary-lifting level stored per vertex
    int **ancestor;                    // ancestor[v][j]: 2^j-th ancestor of v, or 0

    // Allocates storage for vertices 1..n_ and resets the edge list.
    void init(int n_) {
        n = n_;
        sq = static_cast<int>(std::sqrt(n));
        edge = new Edge[n * 2 + 2];
        parent = new int[n + 1];
        iepa = new int[n + 1];
        depth = new int[n + 1];
        depth_index = new std::pair<int, int>[n + 1];
        eadj = new std::vector<int>[n + 1];   // freshly constructed, hence empty
        sqnode = new int[n + 1];
        std::fill(depth, depth + n + 1, -1);
        std::fill(sqnode, sqnode + n + 1, 0);
        // parent/iepa are otherwise only written by dfs(); zero them so index 0
        // and any vertex dfs() never reaches read as "no parent", not garbage.
        std::fill(parent, parent + n + 1, 0);
        std::fill(iepa, iepa + n + 1, 0);
        edgecnt = 0;
    }

    // Adds the undirected edge u - v with colour c and length d. The forward
    // half is stored at index k = edgecnt and the reverse half at k + n, so a
    // tree (at most n - 1 calls) keeps forward indices below n, which dfs()
    // relies on to find an edge's reverse.
    void addEdge(int u, int v, int c, int d) {
        ++edgecnt;
        Edge &fwd = edge[edgecnt];
        Edge &rev = edge[edgecnt + n];
        fwd.v = v;
        rev.v = u;
        fwd.c = rev.c = c;
        fwd.d = rev.d = d;
        eadj[u].push_back(edgecnt);
        eadj[v].push_back(edgecnt + n);
    }

    // Recursive DFS from u (whose parent is pa) filling parent, depth and iepa.
    void dfs(int u, int pa) {
        parent[u] = pa;
        depth[u] = (pa == 0) ? 0 : depth[pa] + 1;
        for (int i : eadj[u]) {
            const int v = edge[i].v;
            if (v == pa)
                continue;
            // edge[i] goes u -> v; its reverse (v -> u) sits n slots away.
            iepa[v] = (i < n) ? i + n : i - n;
            dfs(v, u);
        }
    }

    // Walks up to sq steps from u towards the root. Stops early on meeting an
    // existing core node or reaching the root; otherwise the vertex it ends on
    // becomes a core node.
    void markCoreNode(int u) {
        for (int step = 0; step < sq; ++step) {
            if (sqnode[u]) return;
            if (parent[u] <= 0) return;
            u = parent[u];
        }
        // The final vertex is not checked inside the loop; without this guard an
        // existing core node would be renumbered and sqcnt counted twice.
        if (!sqnode[u])
            sqnode[u] = ++sqcnt;
    }

    // Roots the tree at root_ (vertex 1 by default), fills parent, iepa, depth,
    // depth_index and sqnode, and returns the number of core nodes.
    int sqrt_decompose_tree(int root_ = 1) {
        root = root_;
        sqcnt = 0;
        std::fill(sqnode, sqnode + n + 1, 0);   // drop marks from an earlier call
        parent[0] = 0;
        iepa[0] = 0;
        // Root the DFS at `root` so parent/depth agree with distdfs() and
        // lcadfs(), which both start from `root`.
        dfs(root, 0);
        for (int i = 1; i <= n; ++i)
            depth_index[i] = std::make_pair(depth[i], i);
        std::sort(depth_index + 1, depth_index + n + 1);
        // Deepest vertices first, so shallower walks can stop at core nodes
        // already placed by deeper ones.
        for (int i = n; i >= 1; --i)
            markCoreNode(depth_index[i].second);
        return sqcnt;
    }

    // For core node u: accumulates per-colour (count, length sum) over the edges
    // from u up to its nearest core ancestor (or the root), then adds that
    // ancestor's already-complete totals, giving totals for u's whole path.
    void idfs(int u) {
        std::map<int, std::pair<int, int> > &acc = color_data[sqnode[u]];
        for (int v = u; parent[v] > 0; ) {
            const Edge &up = edge[iepa[v]];
            std::pair<int, int> &slot = acc[up.c];   // value-initialised to (0, 0)
            slot.first += 1;
            slot.second += up.d;
            v = parent[v];
            if (sqnode[v]) {
                for (const auto &kv : color_data[sqnode[v]]) {
                    std::pair<int, int> &dst = acc[kv.first];
                    dst.first += kv.second.first;
                    dst.second += kv.second.second;
                }
                break;
            }
        }
    }

    // Recursive DFS from u filling dist_to_root; d is u's distance to the root.
    void distdfs(int u, int d) {
        dist_to_root[u] = d;
        for (int i : eadj[u])
            if (edge[i].v != parent[u])
                distdfs(edge[i].v, d + edge[i].d);
    }

    // Fills color_data (per core node) and dist_to_root.
    // Requires sqrt_decompose_tree() to have been called.
    void precompute() {
        color_data = new std::map<int, std::pair<int, int> >[sqcnt + 1];
        dist_to_root = new int[n + 1];
        std::fill(dist_to_root, dist_to_root + n + 1, 0);
        // Shallowest core nodes first: idfs(u) reuses the totals of u's nearest
        // core ancestor, which must already be complete.
        for (int i = 1; i <= n; ++i) {
            const int u = depth_index[i].second;
            if (sqnode[u])
                idfs(u);
        }
        distdfs(root, 0);
    }

    // Returns the distance from u to the root when every edge of colour c has
    // its length replaced by x: dist_to_root[u] + count_c * x - length_sum_c.
    int compute(int u, int c, int x) {
        const int curr_dist = dist_to_root[u];
        int cnt = 0;
        int sum = 0;
        while (u != root && u > 0) {
            if (sqnode[u]) {
                // A core node already holds the totals for the rest of the path.
                const std::map<int, std::pair<int, int> > &data = color_data[sqnode[u]];
                const auto it = data.find(c);
                if (it != data.end()) {
                    cnt += it->second.first;
                    sum += it->second.second;
                }
                break;
            }
            const Edge &up = edge[iepa[u]];
            if (up.c == c) {
                ++cnt;
                sum += up.d;
            }
            u = parent[u];
        }
        return curr_dist + (cnt * x - sum);
    }

    // Recursive DFS filling ancestor[u][*]; ancestor[pa][*] is already complete.
    void lcadfs(int u, int pa) {
        ancestor[u][0] = pa;
        for (int j = 1; j <= lgn; ++j)
            ancestor[u][j] = ancestor[ancestor[u][j - 1]][j - 1];
        for (int i : eadj[u])
            if (edge[i].v != pa)
                lcadfs(edge[i].v, u);
    }

    // Builds the binary-lifting table used by getLCA().
    // Requires sqrt_decompose_tree() to have been called.
    void prepareLCA() {
        lgn = static_cast<int>(std::log2(n)) + 1;
        ancestor = new int*[n + 1];
        for (int i = 0; i <= n; ++i)
            ancestor[i] = new int[lgn + 1];
        // Row 0 is the sentinel: every jump from "no vertex" stays at 0.
        std::fill(ancestor[0], ancestor[0] + lgn + 1, 0);
        lcadfs(root, 0);
    }

    // Returns the lowest common ancestor of u and v. Jumps past the root land on
    // vertex 0, whose depth is -1, so they are never taken.
    int getLCA(int u, int v) {
        if (depth[u] < depth[v])
            std::swap(u, v);
        // Lift u to v's depth.
        for (int j = lgn; j >= 0; --j)
            if (depth[ancestor[u][j]] >= depth[v])
                u = ancestor[u][j];
        // One of them was an ancestor of the other.
        if (u == v)
            return u;
        // Lift both to the highest ancestors that still differ.
        for (int j = lgn; j >= 0; --j)
            if (ancestor[u][j] != ancestor[v][j]) {
                u = ancestor[u][j];
                v = ancestor[v][j];
            }
        return ancestor[u][0];
    }
};