天天看點

用python繪制生日蛋糕_AcWing 168. 生日蛋糕 Python 暴力搜尋 注意剪枝

'''

DFS 優化政策

假設第M層是最下面的一層, 第一層是最上面的一層

1. 層數從M層到1層枚舉,先枚舉體積大的層數

2. 每一個圓柱體先枚舉半徑, 再枚舉高度,因為半徑是平方級的,減枝比較快, 半徑和高度都是降序枚舉

3. 第i層的半徑R和高度H的範圍:

i <= R(i) <= R(i+1)-1

假設1層到i-1層已經累計了體積V, 那pi * R(i) * R(i) * H(i) <= pi * (N - V)

可以推導出 R(i) <= sqrt(N-V)

i <= H(i) <= H(i+1) - 1

同樣的,根據pi * R(i) * R(i) * H(i) <= pi * (N - V)

可以推導出H(i) <= (N-V) / R(i) / R(i)

4. 先預處理,計算所有前i-1層的可能的側邊面積最小值和體積最小值

如果目前的累計體積加上上面所有層體積的最小和大于N,則回溯

如果目前累計面積加上上面所喲鄫的面積的最小和大于等于ans, 則回溯

5. 體積和表面積之間有關聯關系

S(i) 表示1到i的圓柱體的側面表面積和

V(i) 表示1到i的圓柱體的體積和

S(i) = 2R(1)*H(1) + 2R(2)*H(2) + ...... 2R(i)*H(i)

> (2/R(i+1)) * [ R(1)^2 * H(1) + R(2)^2 * H(2) ...... R(i)^2 * H(i) ]

= (2/R(i+1)) * V(i)

= (2/R(i+1)) * (N - M層到i-1層的體積累計和)

也就是說如果知道了M層到i-1層的體積累計和,就能得到一個i層到1層的側面面積總和的下界,目前累計

的表面積和加上這個下界,如果超過了ans, 則可以回溯

'''

import math

N = int(input())

M = int(input())

R = [0x7fffffff] * (M+2) # 每一層半徑

H = [0x7fffffff] * (M+2) # 每一層高度

min_a_sum = [0] * (M+1) # 從1層到某一層的最小側面面積累加和

min_v_sum = [0] * (M+1) # 從1層到某一層的最小體積的累加和

for i in range(1, M+1):

min_a_sum[i] = min_a_sum[i-1] + 2*i*i

min_v_sum[i] = min_v_sum[i-1] + i*i*i

ans = [0x7fffffff]

def dfs(cur_level, sum_a, sum_v):

#print(cur_level, sum_a, sum_v)

if cur_level == 0:

if sum_v == N:

ans[0] = min(ans[0], sum_a)

return

sum_a_lower_bound = (N - sum_v) * 2 / R[cur_level + 1]

if sum_a + sum_a_lower_bound > ans[0]:

return

r_upper_bound = min(R[cur_level+1]-1, int(math.sqrt(N-sum_v)))

for r in range(r_upper_bound, cur_level-1, -1):

h_upper_bound = min(H[cur_level+1]-1, int( (N-sum_v) / r / r ))

for h in range(h_upper_bound, cur_level-1, -1):

cur_a = 2*r*h

cur_v = r*r*h

if sum_a + cur_a + min_a_sum[cur_level-1] >= ans[0]:

continue

if sum_v + cur_v + min_v_sum[cur_level-1] > N:

continue

old_r, old_h = R[cur_level], H[cur_level]

R[cur_level], H[cur_level] = r, h

if cur_level != M:

dfs(cur_level-1, sum_a + cur_a, sum_v + cur_v)

else:

dfs(cur_level-1, sum_a + cur_a + r*r, sum_v + cur_v)

R[cur_level], H[cur_level] = old_r, old_h

dfs(M, 0, 0)

if ans[0] == 0x7fffffff:

ans[0] = 0

print(ans[0])