具有py_func的Tensorflow数据映射导致ValueError:无法采用未知等级的Shape的长度

时间:2018-12-10 15:54:59

标签: python tensorflow tensorflow-datasets

我的dataset.map回调中有一个py_func,它返回一个样本权重。参见以下示例:

def get_sample_weight(team, number):
    # Some code
    return np.array([0.5]).astype(np.float32) #return weight

def preprocess_tfrecord_example(example, sample_weights):
    image = tf.image.decode_png(example['detection/encoded'])
    team = example['detection/team_name']

    sample_weight = tf.py_func(get_sample_weight, [team, example['detection/number']], tf.float32)
    image = tf.image.resize_image_with_pad(image,56,108)

    number = tf.expand_dims(tf.cast(example['detection/number'], tf.float32), -1)
    return image/255., number, sample_weight

但是,当我创建数据集并调用model.fit(dataset,....)时,会给出

  

ValueError:无法采用未知等级的Shape的长度。

这是因为Tensorflow无法确定python函数返回的形状,请参见Output from TensorFlow `py_func` has unknown rank/shape

作为一种解决方法,我尝试使用sample_weight.set_shape(1)设置形状,然后在preprocess_tfrecord_example(..)中将其返回。

但这给了

  

ValueError:找到了一个形状为(?,1)的sample_weight数组。为了使用按时间逐步分配的样本权重,应在compile()中指定sample_weight_mode =“ temporal”。如果您只是想使用基于样本的权重,请确保您的sample_weight数组为1D。

如何设置正确的形状,以便在model.fit中使用的sample_weight数组为1D?

0 个答案:

没有答案