/**
 * Tree dynamic-programming helper operating on a rooted tree described by the
 * static fields {@link #w} (one weight per node) and {@link #children}
 * (child lists per node).
 */
public class Main {
    /** Weight of each node, indexed by node id. Must be populated before calling {@link #dfs(int)}. */
    static long[] w;

    /** Child lists indexed by node id; a {@code null} entry is treated as "no children". */
    static List<Integer>[] children;

    /**
     * Tests whether {@code num} is a perfect square.
     *
     * <p>{@code Math.sqrt} works on {@code double}, which cannot represent every
     * {@code long} above 2^53 exactly, so the truncated root is only used as an
     * estimate and its immediate neighbours are checked as well.
     *
     * @param num the value to test
     * @return {@code true} if some non-negative integer {@code r} satisfies {@code r * r == num}
     */
    public static boolean isTSN(long num) {
        if (num < 0) {
            return false;
        }
        long estimate = (long) Math.sqrt((double) num);
        for (long root = Math.max(0L, estimate - 1); root <= estimate + 1; root++) {
            // If root * root overflows it wraps to a negative value here, which
            // can never equal the non-negative num, so no false positive arises.
            if (root * root == num) {
                return true;
            }
        }
        return false;
    }

    /**
     * Computes two values for the subtree rooted at {@code cur}.
     *
     * <ul>
     *   <li>{@code result[1]}: the sum of {@code result[0]} over all children of {@code cur}.</li>
     *   <li>{@code result[0]}: {@code result[1]} plus the largest non-negative gain obtained by
     *       picking one child {@code c} and replacing its {@code result[0]} contribution with
     *       {@code result[1]} of {@code c}, plus 2 when {@code w[cur] * w[c]} is a perfect
     *       square.</li>
     * </ul>
     *
     * <p>NOTE(review): this looks like a "pair a node with at most one child" DP where
     * index 0 is the best score and index 1 is the score with {@code cur} left unpaired;
     * confirm against the problem statement.
     *
     * <p>The traversal is recursive, so the recursion depth equals the height of the tree.
     *
     * @param cur id of the subtree root; must be a valid index into {@link #w} and {@link #children}
     * @return a two-element array {@code {best, withoutPairingCur}} as described above
     */
    public static int[] dfs(int cur) {
        if (children[cur] == null) {
            return new int[] {0, 0};
        }
        int[] ret = new int[] {0, 0};
        int bonus = 0;
        for (int child : children[cur]) {
            int[] childRes = dfs(child);
            boolean squareProduct = isTSN(w[cur] * w[child]);
            int gain = childRes[1] + (squareProduct ? 2 : 0) - childRes[0];
            bonus = Math.max(bonus, gain);
            ret[1] += childRes[0];
        }
        ret[0] = ret[1] + bonus;
        return ret;
    }
}