如何让发电机可以调用?

时间:2018-03-14 14:15:15

标签: python

我尝试从长度为784位的CSV文件创建数据集。这是我的代码:

import tensorflow as tf

f = open("test.csv", "r")
csvreader = csv.reader(f)
gen = (row for row in csvreader)
ds = tf.data.Dataset()
ds.from_generator(gen, [tf.uint8]*28**2)

我收到以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-22-4b244ea66c1d> in <module>()
     12 gen = (row for row in csvreader_pat_trn)
     13 ds = tf.data.Dataset()
---> 14 ds.from_generator(gen, [tf.uint8]*28**2)

~/Documents/Programming/ANN/labs/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in from_generator(generator, output_types, output_shapes)
    317     """
    318     if not callable(generator):
--> 319       raise TypeError("`generator` must be callable.")
    320     if output_shapes is None:
    321       output_shapes = nest.map_structure(

TypeError: `generator` must be callable.

docs说我应该将一个生成器传递给from_generator(),这样我做了什么,gen是一个生成器。但现在它抱怨我的发电机不能可赎回。如何让生成器可调用,以便我可以使用它?

修改 我想补充一点,我使用的是python 3.6.4。这是错误的原因吗?

3 个答案:

答案 0 :(得分:4)

generator参数(可能令人困惑)实际上不应该是生成器,而是可调用的返回可迭代的(例如,生成器函数)。这里最简单的选择可能是使用lambda。此外,还有一些错误:1)tf.data.Dataset.from_generator意味着被称为类工厂方法,而不是来自实例2)该函数(就像TensorFlow中的其他一些)对参数非常挑剔,它想要您要将dtypes和每个数据行的序列设为tuple s(而不是CSV阅读器返回的list),您可以使用例如map

import csv
import tensorflow as tf

with open("test.csv", "r") as f:
    csvreader = csv.reader(f)
    ds = tf.data.Dataset.from_generator(lambda: map(tuple, csvreader),
                                        (tf.uint8,) * (28 ** 2))

答案 1 :(得分:2)

好吧,两年后...但是,嘿!另一个解决方案! :D

这可能不是最干净的答案,但是对于更复杂的生成器,可以使用装饰器。我制作了一个生成两个字典的生成器,例如:

>>> train,val = dataloader("path/to/dataset")
>>> x,y = next(train)
>>> print(x)
{"data": [...], "filename": "image.png"}

>>> print(y)
{"category": "Dog", "category_id": 1, "background": "park"}

当我尝试使用from_generator时,它给了我错误:

>>> ds_tf = tf.data.Dataset.from_generator(
    iter(mm),
    ({"data":tf.float32, "filename":tf.string},
    {"category":tf.string, "category_id":tf.int32, "background":tf.string})
    )
TypeError: `generator` must be callable.

但是后来我写了一个装饰函数

>>> def make_gen_callable(_gen):
        def gen():
            for x,y in _gen:
                 yield x,y
        return gen
>>> train_ = make_gen_callable(train)
>>> train_ds = tf.data.Dataset.from_generator(
    train_,
    ({"data":tf.float32, "filename":tf.string},
    {"category":tf.string, "category_id":tf.int32, "background":tf.string})
    )

>>> for x,y in train_ds:
        break

>>> print(x)
{'data': <tf.Tensor: shape=(320, 480), dtype=float32, ... >,
 'filename': <tf.Tensor: shape=(), dtype=string, ...> 
}

>>> print(y)
{'category': <tf.Tensor: shape=(), dtype=string, numpy=b'Dog'>,
 'category_id': <tf.Tensor: shape=(), dtype=int32, numpy=1>,
 'background': <tf.Tensor: shape=(), dtype=string, numpy=b'Living Room'>
}

但是现在,请注意,要迭代train_,必须将其命名为

>>> for x,y in train_():
        do_stuff(x,y)
        ...

答案 2 :(得分:1)

您链接的

From the docs

  

generator参数必须是返回的可调用对象   支持iter()协议的对象(例如生成器函数)

这意味着您应该能够做到这样的事情:

import tensorflow as tf
import csv

with open("test.csv", "r") as f:
    csvreader = csv.reader(f)
    gen = lambda: (row for row in csvreader)
    ds = tf.data.Dataset()
    ds.from_generator(gen, [tf.uint8]*28**2)

换句话说,您传递的函数必须在调用时生成一个生成器。当使它成为匿名函数(lambda)时,这很容易实现。

或者试试这个,这更接近于在文档中完成的方式:

import tensorflow as tf
import csv


def read_csv(file_name="test.csv"):
    with open(file_name) as f:
        reader = csv.reader(f)
        for row in reader:
            yield row

ds = tf.data.Dataset.from_generator(read_csv, [tf.uint8]*28**2)

(如果您需要的文件名与您设置的默认名称不同,则可以使用functools.partial(read_csv, file_name="whatever.csv")。)

不同之处在于read_csv函数在调用时返回生成器对象,而您构造的内容已经是生成器对象并等效于:

gen = read_csv()
ds = tf.data.Dataset.from_generator(gen, [tf.uint8]*28**2)  # does not work