"""Weighted match score of a pattern ``t`` slid across a text ``s``.

Reads ``s`` and then ``t`` (one per line) from stdin and prints the score
computed by :func:`solve`.
"""
from collections import Counter


def solve(s, t):
    """Return the weighted match score of ``t`` slid across ``s``.

    For every offset ``j`` with ``0 <= j <= len(s) - len(t)`` and every
    index ``i`` of ``t`` where ``s[j + i] == t[i]``, the score gains
    ``i + 1``.  Equivalently, position ``i`` of ``t`` contributes
    ``(i + 1)`` times the number of occurrences of ``t[i]`` in the window
    ``s[i : i + len(s) - len(t) + 1]``.

    Characters are compared directly, so any alphabet is supported (not
    only ``a``-``z``).

    Args:
        s: The text.
        t: The pattern.

    Returns:
        The score as an int. It is 0 when ``t`` is longer than ``s`` (no
        offset fits) and also 0 when ``t`` is empty (empty sum).
    """
    n, m = len(s), len(t)
    if n < m:
        return 0
    width = n - m + 1  # number of valid offsets == size of every window
    # Character counts of the current window s[i : i + width], starting at i = 0.
    window = Counter(s[:width])
    total = 0
    for i, ch in enumerate(t):
        total += (i + 1) * window[ch]  # Counter yields 0 for absent keys
        if i + 1 < m:
            # Slide the window one step right for the next pattern index:
            # s[i : i + width] -> s[i + 1 : i + 1 + width].
            window[s[i]] -= 1
            window[s[i + width]] += 1
    return total


def main():
    """Read ``s`` then ``t`` from stdin (whitespace-stripped) and print the score."""
    s = input().strip()
    t = input().strip()
    print(solve(s, t))


if __name__ == "__main__":
    main()