tensorflow Dataset.from_generator使用产生张量的生成器

时间:2018-12-24 06:38:35

标签: tensorflow tensorflow-datasets

我正在尝试将一些代码转换为新的数据集API,以便可以使用分发策略。以下是我正在尝试做的事情。

def dataset_generator():
    while True:
        features, labels = ex_lib.get_image_batch(), ex_lib.get_feature_batch()
        yield features, labels

def get_ssf_input_fn():
    def input_fn():
        return tf.data.Dataset.from_generator(dataset_generator,
                                              (tf.float32, tf.float32), ([None, config.image_height, config.image_width, config.image_channels], [None, 256]))

    return input_fn

问题是ex_lib.get_image_batchex_lib.get_feature_batch给了我一个张量而不是一个numpy数组,而且我无法更改ex_lib中的代码。另外,由于无法在此处访问sess,因此无法在这里将张量转换为numpy数组。使用此代码,它将抛出

`generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was Tensor("GetImageBatch:0", dtype=uint8)

有没有办法让我的input_fn返回一个数据集?

1 个答案:

答案 0 :(得分:0)

我可以使用以下技巧来解决此问题。效率还可以。

tf.data.Dataset.from_tensors(0).repeat().map(lambda _: dataset_generator())