我希望在keras中实现PointNet(https://arxiv.org/pdf/1612.00593.pdf)的变体,但是我很难重复上下文向量(g)一段时间,以便我可以将它与前一个行连接起来缺乏上下文的层(前)。我尝试了Repeat()和keras.backend.Tile()。
input = Input(shape=(None,3))
x = TimeDistributed(Dense(128, activation = 'relu'))(input)
pre = TimeDistributed(Dense(256, activation = 'relu'))(x)
g = GlobalMaxPooling1D()(pre)
x = Lambda(merge_on_single, output_shape=(None,512))([pre,g])
print(x.shape)
这是我提出的lambda定义。
def merge_on_single(v):
#v[0] is variable length tensor, v[1] is the single vector
return Concatenate()([K.repeat(v[1],K.get_variable_shape(v[0])),v[0]])
但是会发生以下错误:
TypeError:传递给'Pack'Op'值'的列表中的张量具有并非全部匹配的类型[int32,,int32]。
更新:
所以我可以通过执行以下操作让图层不会出错:
input = Input(shape=(None,3))
num_point = K.placeholder(input.get_shape()[1].value, dtype=tf.int32)
#first global feature layer
x = TimeDistributed(Dense(512, activation = 'relu'))(input)
x = TimeDistributed(Dense(256, activation = 'relu'))(x)
g = GlobalMaxPooling1D()(x)
g = K.reshape(g,(-1,1,256))
g = K.tile(x, [1,num_point,1])
concat_feat = K.concatenate([x, g])
但现在,我收到以下错误:
AttributeError: 'Tensor' object has no attribute '_keras_history'
答案 0 :(得分:0)
我怀疑罪魁祸首是K.get_variable_shape(v[0])
。由于v[0]
的类型为int32
(由您的错误指定),因此当您获得该形状时,它将返回None。 Concatenate希望所有输入都是相同的类型。