我正在尝试使用py_func从函数返回数据集,以在tensorflow数据集管道/ api中使用。但是,py_func会引发错误:
TypeError: Expected DataType for argument 'Tout' not <class 'tensorflow.python.data.ops.dataset_ops.Dataset'>.
一个最小的示例如下:
import tensorflow as tf
import numpy as np
def fn(x, y):
a = tf.data.Dataset_from_tensors((x, y))
b = tf.data.Dataset_from_tensors((x, y))
return a.concatenate(b)
if __name__ == "__main__":
features = np.random.rand(5, 5, 5, 1)
labels = np.random.rand(5, 5)
dataset = tf.data.Dataset.from_tensors((features, labels))
dataset = dataset.flat_map(
lambda feature, label: tuple(tf.py_func(
fn, [feature, label], [tf.data.Dataset])))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
val = sess.run(next_element)
这是tensorflow的错误吗,还是我使用api的方式不正确?谢谢!
答案 0 :(得分:0)
tf.py_func
使我们能够在不容易获得等效的Tensorflow API时包装python代码。
我认为它无法返回Tensorflow类。
作为示例,我复制了此代码,并将其发布到另一个线程中。此代码使用Python API格式化从文件读取的日期。返回的数据类型是Tensorflow数据类型。
tf.py_func
方便地在API doc.
import tensorflow as tf
from datetime import datetime
sess = tf.Session()
#Could be refactored
def convert_to_date(text):
date = datetime.strptime(text.decode('ascii'), '%b %d %Y %I:%M%p')
return date.strftime('%b %d %Y %I:%M%p')
filenames = ["C:/Machine Learning/text.txt"]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
tf.data.TextLineDataset
dataset = dataset.flat_map(
lambda filename :
tf.data.TextLineDataset( filename ) ).map( lambda text :
tf.py_func(convert_to_date,
[text],
[tf.string]))
iterator = dataset.make_one_shot_iterator()
date = iterator.get_next()
print(sess.run([date]))