TF2 / Keras切片张量使用[:,:,0]

时间:2019-08-10 02:01:59

标签: python tensorflow keras deep-learning tensorflow2.0

在TF 2.0 Beta中,我正在尝试:

x = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
print(x.shape) # (None, 240, 2)
a = x[:, :, 0]
print(a.shape) # <unknown>

在TF 1.x中,我可以这样做:

x = tf1.placeholder(tf1.float32, (None, 240, 2)
a = x[:, :, 0]

,它将正常工作。如何在TF 2.0中实现这一目标?我认为

tf.split(x, 2, axis=2)

可能有效,但是我想使用切片而不是对2(轴2的暗角)进行硬编码。

1 个答案:

答案 0 :(得分:0)

区别在于从Input返回的对象代表一个图层,而不是类似于占位符或张量的任何东西。因此,上面tf 2.0代码中的x是一个图层对象,而tf 1.x代码中的x是张量的占位符。

您可以定义切片层以执行操作。可以使用现成的图层,但是对于像这样的简单切片,Lambda图层非常易于阅读,也许最接近您在tf 1.x中进行切片的方式。

类似这样的东西:

input_lyr = tf.keras.layers.Input(shape=(240, 2), dtype=tf.float32)
sliced_lyr = tf.keras.layers.Lambda(lambda x: x[:,:,0])

您可以像这样在您的keras模型中使用它:

model = tf.keras.models.Sequential([
    input_lyr,
    sliced_lyr,
    # ...
    # <other layers>
    # ...
])

当然,以上是特定于 keras 模型的。相反,如果您有张量而不是keras图层对象,那么切片的工作原理与以前完全相同。像这样:

my_tensor = tf.random.uniform((8,240,2))
sliced = my_tensor[:,:,0]

print(my_tensor.shape)
print(sliced.shape)

输出:

(8, 240, 2)
(8, 240)

符合预期