# uu，帮忙看看最后一题：(friends, could you help me look at the last problem?)
from collections import defaultdict
# Read the parameters line, then the parent list. The k-th value (0-based)
# of the parent list is the parent of node k + 2, so node 1 never appears
# as a child.
first_line = input().split()
n, a, b, m = map(int, first_line)
roots = [int(tok) for tok in input().split()]

# nodes[p] -> list of children of node p, in input order.
nodes = defaultdict(list)
for child, parent in enumerate(roots, start=2):
    nodes[parent].append(child)
def dfs(i, children=None, base=None, factor=None, mod=None):
    """Return the modular value accumulated over the subtree rooted at node ``i``.

    A node with no children has value 0. Otherwise its value is

        sum over children c of ((value(c) % mod + base**node % mod) * factor) % mod

    reduced modulo ``mod`` once more, where ``node`` is the parent's id.

    Args:
        i: id of the node to evaluate.
        children: mapping node -> list of child ids; defaults to the
            module-level ``nodes``.
        base, factor, mod: default to the module-level ``a``, ``b`` and
            ``m``, so ``dfs(1)`` behaves exactly as before.

    Returns:
        The value of node ``i`` (0 if it has no children).

    The walk is iterative. The earlier recursive version raised
    RecursionError on deep trees (deeper than Python's default recursion
    limit of 1000). ``pow(base, node, mod)`` replaces ``a**i % m``, which
    built a huge intermediate integer before reducing it.
    """
    # Only consult the globals when the caller did not supply a value.
    children = nodes if children is None else children
    base = a if base is None else base
    factor = b if factor is None else factor
    mod = m if mod is None else mod

    # Pre-order walk: each node is appended before its children, so walking
    # ``order`` backwards visits every child before its parent.
    order = []
    stack = [i]
    while stack:
        node = stack.pop()
        order.append(node)
        # .get() rather than [] so a defaultdict is not grown with empty leaves.
        stack.extend(children.get(node, ()))

    value = {}
    for node in reversed(order):
        kids = children.get(node, ())
        if not kids:
            value[node] = 0
            continue
        # The power term depends only on the parent id, so compute it once.
        step = pow(base, node, mod)
        value[node] = sum(
            (value[kid] % mod + step) * factor % mod for kid in kids
        ) % mod
    return value[i]
# Evaluate the tree from node 1, which the parent list never names as a child.
answer = dfs(1)
print(answer)