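The official TensorFlow documentation states that "tf.train.string_input_producer is deprecated and will be removed in a future version", so I tried to replace it with the suggested method, tf.data.Dataset.from_tensor_slices().
After making the change, I get the following error: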
TypeError: Tensors in list passed to 'input' of 'PyFunc' Op have types [<NOT CONVERTIBLE TO TENSOR>] that are invalid. Tensors: [<DatasetV1Adapter shapes: (), types: tf.string>]
The code is as follows:
with tf.device("/cpu:0"), tf.name_scope(scope):
    '''This is the correct but deprecated version'''
    input_ops['id'] = tf.train.string_input_producer(
        tf.convert_to_tensor(data_id), capacity=128
    ).dequeue(name='input_ids_dequeue')

    ''' The following replaced code creates an error
    input_ops['id'] = tf.data.Dataset.from_tensor_slices(
        tf.convert_to_tensor(data_id)
    ).shuffle(128)
    '''

    img, q, a = dataset.get_data(data_id[0])

    def load_fn(id):
        # image [n, n], q: [m], a: [l]
        img, q, a = dataset.get_data(id)
        return (id, img.astype(np.float32), q.astype(np.float32),
                a.astype(np.float32))

    input_ops['id'], input_ops['img'], input_ops['q'], input_ops['a'] = \
        tf.py_func(
            load_fn,
            inp=[input_ops['id']],
            Tout=[tf.string, tf.float32, tf.float32, tf.float32],
            name='func'
        )
Here, data_id is a list of n numbers. Can anyone help me?
Thanks.
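
For reference, the error message indicates that the Dataset object itself is being passed to tf.py_func, which only accepts tensors. The old .dequeue() call returned a tf.string tensor, and the closest tf.data equivalent appears to be the get_next() tensor of an iterator over the dataset. Below is a minimal sketch of that idea, not a definitive fix. It assumes the tf.compat.v1 graph API is available, it replaces dataset.get_data with a hypothetical stand-in load_fn, and it assumes the IDs are strings (as Tout=tf.string suggests). The .repeat() call is added on the assumption that the pipeline should keep cycling through the IDs as the queue-based producer did.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API (assumed to be available)

# Hypothetical stand-in for the question's `data_id` list.
data_id = ["0", "1", "2", "3"]


def load_fn(id):
    # Hypothetical stand-in for dataset.get_data(id).
    # py_func hands string tensors to Python as bytes.
    value = float(id.decode("utf-8"))
    img = np.full((2, 2), value, dtype=np.float32)
    q = np.full((3,), value, dtype=np.float32)
    a = np.full((1,), value, dtype=np.float32)
    return id, img, q, a


def read_examples(num_examples):
    with tf.Graph().as_default():
        ids = (
            tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(data_id))
            .shuffle(128)
            .repeat()  # keep cycling through the IDs
        )
        # A Dataset is not a tensor; get_next() yields the tf.string
        # tensor that py_func can accept as input.
        next_id = tf1.data.make_one_shot_iterator(ids).get_next()
        ops = tf1.py_func(
            load_fn,
            inp=[next_id],
            Tout=[tf.string, tf.float32, tf.float32, tf.float32],
            name="func",
        )
        with tf1.Session() as sess:
            return [sess.run(ops) for _ in range(num_examples)]


examples = read_examples(8)
```

In the original code, the same change would amount to assigning the get_next() tensor to input_ops['id'] before the tf.py_func call, leaving the rest of the block unchanged.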