我正在尝试将一些代码转换为新的数据集API,以便可以使用分发策略。以下是我正在尝试做的事情。
def dataset_generator():
while True:
features, labels = ex_lib.get_image_batch(), ex_lib.get_feature_batch()
yield features, labels
def get_ssf_input_fn():
def input_fn():
return tf.data.Dataset.from_generator(dataset_generator,
(tf.float32, tf.float32), ([None, config.image_height, config.image_width, config.image_channels], [None, 256]))
return input_fn
问题是ex_lib.get_image_batch
和ex_lib.get_feature_batch
给了我一个张量而不是一个numpy数组,而且我无法更改ex_lib中的代码。另外,由于无法在此处访问sess
,因此无法在这里将张量转换为numpy数组。使用此代码,它将抛出
`generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was Tensor("GetImageBatch:0", dtype=uint8)
有没有办法让我的input_fn返回一个数据集?
答案 0 :(得分:0)
我可以使用以下技巧来解决此问题。效率还可以。
tf.data.Dataset.from_tensors(0).repeat().map(lambda _: dataset_generator())