// Problem 1: given parallel lists of process ids and their parent ids, report how many
// processes are terminated when process n is killed (n itself plus all its descendants).
import java.util.*;

public class Main {

    /**
     * Reads three lines from standard input and prints the kill count.
     *
     * <p>Line 1: whitespace-separated pids. Line 2: whitespace-separated ppids (parallel to
     * line 1). Line 3: the pid to kill.
     */
    public static void main(String[] args) {
        try (Scanner scan = new Scanner(System.in)) {
            int[] pids = parseInts(scan.nextLine());
            int[] ppids = parseInts(scan.nextLine());
            int n = Integer.parseInt(scan.nextLine().trim());
            System.out.println(getNums(pids, ppids, n));
        }
    }

    /**
     * Parses a line of whitespace-separated integers; tolerates repeated, leading and
     * trailing whitespace.
     *
     * @param line the raw input line
     * @return the parsed integers, or an empty array for a blank line
     */
    private static int[] parseInts(String line) {
        String trimmed = line.trim();
        if (trimmed.isEmpty()) {
            return new int[0];
        }
        String[] tokens = trimmed.split("\\s+");
        int[] values = new int[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            values[i] = Integer.parseInt(tokens[i]);
        }
        return values;
    }

    /**
     * Counts the processes terminated by killing process {@code n}: {@code n} itself plus
     * every distinct descendant reachable through the parent table.
     *
     * <p>Entry {@code i} pairs {@code pids[i]} with its parent {@code ppids[i]}. If the arrays
     * differ in length, only the first {@code min(pids.length, ppids.length)} pairs are used.
     *
     * @param pids  process ids
     * @param ppids parent process id of each entry in {@code pids}
     * @param n     the process to kill
     * @return 0 if {@code n} is not among the paired pids, otherwise 1 + number of distinct
     *         descendants of {@code n}
     */
    static int getNums(int[] pids, int[] ppids, int n) {
        // Only complete (pid, ppid) pairs are meaningful; avoids indexing past the shorter array.
        int len = Math.min(pids.length, ppids.length);
        HashMap<Integer, List<Integer>> children = new LinkedHashMap<>();
        boolean exists = false;
        for (int i = 0; i < len; i++) {
            if (pids[i] == n) {
                exists = true;
            }
            children.computeIfAbsent(ppids[i], k -> new ArrayList<>()).add(pids[i]);
        }
        if (!exists) {
            return 0;
        }
        return getnum(children, n) + 1;
    }

    /**
     * Counts the distinct descendants of {@code n} (excluding {@code n} itself).
     *
     * <p>Iterative with a visited set, so self-parented entries, cycles, or duplicate pids in
     * the table cannot cause infinite recursion, stack overflow, or double counting.
     *
     * @param map parent pid -&gt; list of child pids
     * @param n   the root process
     * @return number of distinct processes reachable from {@code n} via child links
     */
    static int getnum(HashMap<Integer, List<Integer>> map, int n) {
        Set<Integer> seen = new HashSet<>();
        seen.add(n); // never count the root, even if a cycle leads back to it
        Deque<Integer> pending = new ArrayDeque<>();
        pending.push(n);
        int count = 0;
        while (!pending.isEmpty()) {
            List<Integer> kids = map.getOrDefault(pending.pop(), Collections.<Integer>emptyList());
            for (int kid : kids) {
                if (seen.add(kid)) {
                    count++;
                    pending.push(kid);
                }
            }
        }
        return count;
    }
}