import sys
# dfs() below recurses once per tree level, so lift the default limit.
sys.setrecursionlimit(100000)

# First input line: vertex count and query count.
n, q = (int(tok) for tok in input().split())

# g[v]        -- adjacency list of vertex v (0-based)
# parent[v]   -- parent of v once dfs() has run (-1 for the root)
# son[v]      -- child count of v; the query loop keeps it equal to the
#                number of children whose captured flag is set
# captured[v] -- 1/0 flag, every vertex starts at 1
g = [[] for _ in range(n)]
parent = [0] * n
son = [0] * n
captured = [1] * n
def dfs(v, par=-1):
    """Fill ``parent`` and ``son`` for the subtree hanging below ``v``.

    Walks the adjacency lists in the module-level ``g`` and, for every
    vertex reached, stores its parent in ``parent`` and its number of
    children in ``son``.

    The walk uses an explicit stack instead of recursion, so a very deep
    (path-like) tree cannot hit the interpreter recursion limit or
    overflow the C stack.

    Args:
        v: Vertex to start from.
        par: Parent of ``v``; ``-1`` means ``v`` has no parent (root).
    """
    pending = [(v, par)]
    while pending:
        node, up = pending.pop()
        parent[node] = up
        # Every neighbour except the edge back to the parent is a child.
        son[node] = len(g[node]) - (1 if up != -1 else 0)
        for to in g[node]:
            if to != up:
                pending.append((to, node))
# Read the n - 1 undirected edges; the input is 1-based, storage is 0-based.
for _ in range(n - 1):
    u, w = (int(tok) - 1 for tok in input().split())
    g[u].append(w)
    g[w].append(u)

# Root the tree at vertex 0 to populate parent[] and son[].
dfs(0)
# Running answer, starting at 1 while every vertex has its captured flag set.
# NOTE(review): the update rule below looks like a count of connected
# components among captured vertices -- confirm against the problem statement.
# NOTE(review): original indentation was lost; the print is assumed to run
# once per query.
ans = 1
for _ in range(q):
    a = int(input()) - 1  # queries are 1-based
    p = parent[a]
    has_parent = p > -1

    # Captured neighbours of a: captured children, plus the parent if it
    # exists and is captured.
    active_neighbours = son[a]
    if has_parent:
        active_neighbours += captured[p]

    # Clearing a set flag adds (neighbours - 1) to the answer; setting a
    # cleared flag subtracts the same amount.
    if captured[a]:
        ans += active_neighbours - 1
    else:
        ans -= active_neighbours - 1

    captured[a] ^= 1

    # Keep the parent's captured-children counter in sync with the new flag.
    if has_parent:
        son[p] += 1 if captured[a] else -1

    print(ans)