如何使用Tensorflow功能API沿批量维度广播?

时间:2020-08-02 00:14:53

标签: python tensorflow

在某些应用中,例如插槽注意(在Pytorch spaCy's entity types中实现),必须沿批处理维度进行广播。但是,我看不到如何使用功能性API。例如,

import tensorflow as tf
const = tf.ones((1,4))
input = tf.keras.layers.Input((4))

const = tf.broadcast_to(const, input.shape)

引发以下错误:

ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 4)

因此,我求助于tf.keras.Model的子类化,但是我想将我的代码保留在功能性API中。有人知道如何做到这一点吗?

1 个答案:

答案 0 :(得分:0)

最后使用tf.keras.backend.shape找到了答案:

const = tf.ones((1,4))
input = tf.keras.layers.Input((4))

const = tf.broadcast_to(const, [tf.keras.backend.shape(input)[0], 4] )

# Shape of const is now (None, 4)