# First line: two integers n and m (presumably len(a) and len(b); neither is
# used further). Next two lines: the two position lists.
#   a -- positions that move (check() extends a reach from each a[i])
#   b -- positions that must all be visited (check() advances a pointer over b)
# Both are sorted ascending because check() scans them left to right.
n, m = map(int, input().split())
a = sorted(map(int, input().split()))
b = sorted(map(int, input().split()))
def check(t, heads=None, tracks=None):
    """Return True if every track can be visited within a time budget of ``t``.

    Greedy test: heads are processed left to right. Each head must cover the
    leftmost still-unvisited track (possibly lying to its left) and then
    extends as far right as the remaining budget allows.

    Args:
        t: non-negative integer time budget.
        heads: head positions, sorted ascending; defaults to module-level ``a``.
        tracks: track positions, sorted ascending; defaults to module-level ``b``.

    Returns:
        True if all tracks are covered within ``t``, otherwise False.
        An empty ``tracks`` list is trivially covered.
    """
    if heads is None:
        heads = a
    if tracks is None:
        tracks = b
    if not tracks:
        return True

    total = len(tracks)
    ind = 0  # index of the leftmost track not yet covered
    for pos in heads:
        # Distance this head must travel left to reach the first uncovered
        # track (0 when that track is at or to the right of the head).
        left = max(pos - tracks[ind], 0)
        if left > t:
            # Unreachable for this head; every later head is further right.
            return False
        # Farthest right point reachable while still visiting the left target:
        #   go left first, then right:  2*left + x <= t  ->  x <= t - 2*left
        #   go right first, then left:  2*x + left <= t  ->  x <= (t - left)//2
        reach = pos + max(t - 2 * left, (t - left) // 2)
        while ind < total and tracks[ind] <= reach:
            ind += 1
        if ind == total:
            return True
    return False
# Binary search for the smallest t with check(t) True.
# Invariant: check(lo) is False and check(hi) is True.
#   lo = -1 is a sentinel (no negative time is valid), so t = 0 itself is
#   tested; starting at 0 would print 1 when every track is already covered.
#   hi = 2 * span, where span is the distance between the extreme positions.
#   This bound is always feasible: the first head can reach the leftmost
#   track (<= span away) and come back past the rightmost one.
if b:
    points = a + b
    hi = 2 * (max(points) - min(points))
else:
    hi = 0  # nothing to visit
lo = -1
while hi - lo > 1:
    mid = (lo + hi) // 2
    if check(mid):
        hi = mid
    else:
        lo = mid
print(hi)