ValueError:`validation_data`应该是一个元组`(val_x,val_y,val_sample_weight)`或`(val_x,val_y)`。找到:<__ main __。Generator对象,位于>

时间:2019-02-24 17:55:00

标签: python-3.x keras-2

我知道已经被问了几次了,但是没有答案符合我的要求。

我有一个带文本(报纸内容)和标签的csv文件,列0和1。

我试图编写第一个用于文本分类的自定义生成器,但出现错误

ValueError: `validation_data` should be a tuple `(val_x, val_y, val_sample_weight)` or `(val_x, val_y)`. Found: <__main__.Generator object at 0xd376a6e80>

这是课程

class Generator(object):

    def __init__(self, data_file):
        self.data_file = data_file
        self.length = -1

    def __iter__(self):
        while True:
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    yield row[0], row[1]

    def __len__(self):
        if self.length ==  -1:
            n_rows = 0
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    n_rows += 1
            self.length = n_rows
        return self.length

我也使用yield row[0], row[1]return进行了尝试。都不起作用。

感谢您的帮助

1 个答案:

答案 0 :(得分:1)

在使Generator类从keras.utils.Sequence继承方法之前,我一直遇到相同的错误(请参见fit_generator documentation)。您可以尝试以下方法:

import keras

class Generator(keras.utils.Sequence):
    def __init__(self, data_file):
        self.data_file = data_file
        self.length = -1

    def __iter__(self):
        while True:
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    yield row[0], row[1]

    def __len__(self):
        if self.length ==  -1:
            n_rows = 0
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    n_rows += 1
            self.length = n_rows
        return self.length