我需要展平权重和偏差字典以进行梯度检查,我创建了此功能来展平我的字典,该功能可以正常工作,但是我似乎找不到恢复这一过程的方法。
@staticmethod
def flatten_dic(dic):
keys = []
count = 0
theta = np.array([])
for i in dic.keys():
new_vector = np.reshape(dic[i], (-1, 1))
keys = keys + [i] * new_vector.shape[0]
if count == 0:
theta = new_vector
else:
theta = np.concatenate((theta, new_vector), axis=0)
count = count + 1
return theta, keys
输入
{"W1":[[1,2,3],[3,2,1]],"W2":[1,2,3]}
它输出
[1,2,3,3,2,1,1,2,3]
答案 0 :(得分:1)
r = {"W1":[[1,2,3],[3,2,1]],"W2":[1,2,3]}
result = []
def flatten(_list):
if type(_list[0]) == list:
for e in _list:
flatten(e)
else:
result.extend(_list)
[flatten(e) for e in r.values()]
这将为您提供一个扁平化的结果列表,这还将保留您的原始词典。