Keras - How to remove a useless dimension without breaking the computation graph?

Posted: 2019-03-28 08:26:39

Tags: keras

While building a deep learning model, I use K.squeeze to remove a useless (size-1) dimension from a tensor whose first two dimensions are None.

import keras.backend as K
>>> K.int_shape(user_input_for_TD)
(None, None, 1, 32)
>>> K.int_shape(K.squeeze(user_input_for_TD, axis=-2))
(None, None, 32)
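To make the shapes concrete, here is a NumPy sketch on a plain array, with hypothetical sizes (batch 2, 5 time steps) standing in for the two None dimensions. It shows that removing the size-1 axis with np.squeeze(axis=-2) and integer-indexing that axis with [:, :, 0, :] give the same result. It illustrates the data transformation only, not how Keras tracks layers in a model graph.

```python
import numpy as np

# Hypothetical concrete sizes for the symbolic shape (None, None, 1, 32):
# batch=2, time steps=5, a singleton axis, embedding size 32.
x = np.random.rand(2, 5, 1, 32)

# Drop the size-1 axis explicitly by its position (second from last).
squeezed = np.squeeze(x, axis=-2)

# Integer indexing on that axis removes it as well.
indexed = x[:, :, 0, :]

print(squeezed.shape)                     # (2, 5, 32)
print(indexed.shape)                      # (2, 5, 32)
print(np.array_equal(squeezed, indexed))  # True
```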

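However, inside the model this call produces the error below. It seems that K.squeeze breaks the computation graph. Is there any way to avoid this problem? Perhaps the function does not support gradient computation?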

File "/home/sundong/anaconda3/envs/py36/lib/python3.6/site-packages/keras/engine/network.py", line 1325, in build_map
    node = layer._inbound_nodes[node_index]
AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

Here is the full code block that causes the error.

user_embedding_layer = Embedding(
            input_dim=len(self.data.visit_embedding),
            output_dim=32,
            weights=[np.array(list(self.data.visit_embedding.values()))],
            input_length=1,
            trainable=False)
...
all_areas_lstm = LSTM(1024, return_sequences=True)(all_areas_rslt)   # (None, None, 1024)
user_input_for_TD = Lambda(lambda x: x[:, :, 0:1])(multiple_inputs)  # (None, None, 1) 
user_input_for_TD = TimeDistributed(user_embedding_layer)(user_input_for_TD) # (None, None, 1, 32) 
user_input_for_TD = K.squeeze(user_input_for_TD, axis=-2) # (None, None, 32) 
aggre_threeway_inputs = Concatenate()([user_input_for_TD, all_areas_lstm]) # should be (None, None, 1056) 
threeway_encoder = TimeDistributed(ThreeWay(output_dim=512))
three_way_rslt = threeway_encoder(aggre_threeway_inputs) # should be (None, None, 512) 
logits = Dense(365, activation='softmax')(three_way_rslt) # should be (None, None, 365)
self.model = keras.Model(inputs=multiple_inputs, outputs=logits)

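If I remove the following two lines (so that the input does not go through the embedding layer), the code works fine. In that case, the output of aggre_threeway_inputs = Concatenate()([user_input_for_TD, all_areas_lstm]) has shape (None, None, 1025).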

user_input_for_TD = TimeDistributed(user_embedding_layer)(user_input_for_TD)
user_input_for_TD = K.squeeze(user_input_for_TD, axis=-2)

1 answer:

Answer (score: 0)

I solved the problem by doing the indexing inside a Lambda layer instead of calling K.squeeze directly.

from keras.layers import Lambda
>>> K.int_shape(user_input_for_TD)
(None, None, 1, 32)
>>> K.int_shape(Lambda(lambda x: x[:, :, 0, :])(user_input_for_TD))
(None, None, 32)
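A minimal end-to-end sketch of this fix, assuming TensorFlow 2.x with its bundled tf.keras (the question used standalone keras). All sizes and the constant names (VOCAB_SIZE, FEAT_DIM, LSTM_UNITS, N_CLASSES) are hypothetical stand-ins. As a simplification, Embedding is applied directly to the (None, None, 1) IDs rather than through TimeDistributed. This still yields the (None, None, 1, 32) shape from the question. The key point is that the size-1 axis is removed inside a layer, so keras.Model can trace every step from inputs to outputs.

```python
import numpy as np
import tensorflow as tf

keras = tf.keras
layers = tf.keras.layers

# Hypothetical sizes, kept small for illustration.
VOCAB_SIZE = 100  # stands in for len(self.data.visit_embedding)
EMB_DIM = 32
FEAT_DIM = 8      # stands in for the per-step features fed to the LSTM
LSTM_UNITS = 16   # stands in for 1024
N_CLASSES = 5     # stands in for 365

# Integer IDs per time step: (None, None, 1)
ids = keras.Input(shape=(None, 1), dtype="int32")
# Other per-step features: (None, None, FEAT_DIM)
feats = keras.Input(shape=(None, FEAT_DIM))

# Embedding keeps the trailing size-1 axis: (None, None, 1, EMB_DIM)
emb = layers.Embedding(input_dim=VOCAB_SIZE, output_dim=EMB_DIM)(ids)

# Remove the size-1 axis inside a Lambda layer so the operation is recorded
# as a layer in the model graph: (None, None, EMB_DIM)
emb = layers.Lambda(lambda x: x[:, :, 0, :])(emb)

lstm_out = layers.LSTM(LSTM_UNITS, return_sequences=True)(feats)  # (None, None, LSTM_UNITS)
merged = layers.Concatenate()([emb, lstm_out])  # (None, None, EMB_DIM + LSTM_UNITS)
probs = layers.Dense(N_CLASSES, activation="softmax")(merged)

model = keras.Model(inputs=[ids, feats], outputs=probs)

# Quick check with random data: a batch of 2 sequences, 7 steps each.
batch_ids = np.random.randint(0, VOCAB_SIZE, size=(2, 7, 1)).astype("int32")
batch_feats = np.random.rand(2, 7, FEAT_DIM).astype("float32")
pred = model.predict([batch_ids, batch_feats], verbose=0)
print(pred.shape)  # (2, 7, 5)
```

Wrapping the original backend call the same way, as Lambda(lambda x: K.squeeze(x, axis=-2)), should also work. The indexing form is simply what the answer above uses.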