Upgrading tf.train.string_input_producer to tf.data.Dataset.from_tensor_slices produces an error

Time: 2019-07-05 20:58:36

Tags: python tensorflow deprecated

As stated on the official TensorFlow website:

tf.train.string_input_producer is deprecated and will be removed in a future version

I tried to replace it with the suggested method, tf.data.Dataset.from_tensor_slices().

However, after the change I get the following error:

TypeError: Tensors in list passed to 'input' of 'PyFunc' Op have types [<NOT CONVERTIBLE TO TENSOR>] that are invalid. Tensors: [<DatasetV1Adapter shapes: (), types: tf.string>]
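The message says the single entry in `inp` is a `DatasetV1Adapter`, i.e. a `tf.data.Dataset` object, which cannot be converted to a tensor. `tf.py_func` expects `inp` to be a list of tensors, and `tf.train.string_input_producer(...).dequeue()` returned exactly such a tensor, whereas `from_tensor_slices(...).shuffle(128)` returns a pipeline description. The minimal sketch below assumes TensorFlow 2.x with eager execution (the question itself uses TF 1.x graph mode) and uses hypothetical string IDs in place of `data_id`; it only illustrates that tensors appear when the dataset is iterated. In TF 1.x graph mode the analogous step would be an iterator, for example `tf.compat.v1.data.make_one_shot_iterator(ds).get_next()`.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution enabled

# Hypothetical IDs standing in for data_id from the question
data_id = ["10", "11", "12"]

ds = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(data_id)).shuffle(128)

# ds describes a pipeline; it is a Dataset, not a tf.Tensor,
# so it cannot be used as an element of tf.py_func's `inp`.
print(isinstance(ds, tf.data.Dataset))  # True
print(isinstance(ds, tf.Tensor))        # False

# Individual elements only become tensors when the dataset is iterated.
first = next(iter(ds))  # a scalar tf.string tensor
print(first.dtype, first.shape)
```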

The code is as follows:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x, graph mode

# scope, dataset, data_id and input_ops are defined elsewhere in the original code.
with tf.device("/cpu:0"), tf.name_scope(scope):
    # This is the correct but deprecated version:
    # a queue-based producer whose dequeue() yields a scalar tensor.
    input_ops['id'] = tf.train.string_input_producer(
        tf.convert_to_tensor(data_id), capacity=128
    ).dequeue(name='input_ids_dequeue')

    # The following replacement code causes the error:
    #
    # input_ops['id'] = tf.data.Dataset.from_tensor_slices(
    #     tf.convert_to_tensor(data_id)
    # ).shuffle(128)

    img, q, a = dataset.get_data(data_id[0])

    def load_fn(id):
        # image [n, n], q: [m], a: [l]
        img, q, a = dataset.get_data(id)
        return (id, img.astype(np.float32), q.astype(np.float32),
                a.astype(np.float32))

    # Wrap the Python loader as a TF op; `inp` must be a list of tensors.
    input_ops['id'], input_ops['img'], input_ops['q'], input_ops['a'] = \
        tf.py_func(
            load_fn,
            inp=[input_ops['id']],
            Tout=[tf.string, tf.float32, tf.float32, tf.float32],
            name='func'
        )
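One way to restructure this for `tf.data` is to move `load_fn` into `Dataset.map`, so the Python loader runs once per element inside the pipeline instead of receiving the `Dataset` object. This is a sketch under assumptions, not the asker's code: it targets TensorFlow 2.x, uses `tf.numpy_function` in place of `tf.py_func`, and replaces `dataset.get_data` with a hypothetical `get_data` stub returning zero arrays of made-up shapes.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x

# Hypothetical stand-in for the question's dataset.get_data(id):
# returns an image [n, n], a question vector [m] and an answer vector [l].
def get_data(id_):
    img = np.zeros((4, 4))
    q = np.zeros(5)
    a = np.zeros(3)
    return img, q, a

def load_fn(id_):
    # Same body as the question's load_fn; id_ arrives as a bytes-like value.
    img, q, a = get_data(id_)
    return (id_, img.astype(np.float32), q.astype(np.float32),
            a.astype(np.float32))

data_id = ["0", "1", "2", "3"]  # hypothetical IDs

ds = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(data_id))
ds = ds.shuffle(128)
# Call the Python loader per element; tuple() keeps the four outputs
# as separate components instead of a list.
ds = ds.map(lambda id_: tuple(tf.numpy_function(
    load_fn, [id_], [tf.string, tf.float32, tf.float32, tf.float32])))

for id_, img, q, a in ds.take(2):
    print(id_.numpy(), img.shape, q.shape, a.shape)
```

The outputs of `tf.numpy_function` carry no static shape information, so `tf.ensure_shape` can be applied inside the `map` function if later stages need known shapes.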

Here, data_id is a list of n numbers. Can anyone help me?

Thanks.

0 answers

No answers yet.