我知道已经被问了几次了,但是没有答案符合我的要求。
我有一个带文本(报纸内容)和标签的csv文件,列0和1。
我试图编写第一个用于文本分类的自定义生成器,但出现错误
ValueError: `validation_data` should be a tuple `(val_x, val_y, val_sample_weight)` or `(val_x, val_y)`. Found: <__main__.Generator object at 0xd376a6e80>
这是课程
class Generator(object):
def __init__(self, data_file):
self.data_file = data_file
self.length = -1
def __iter__(self):
while True:
with open(self.data_file, 'r') as f:
reader = csv.reader(f)
for row in reader:
yield row[0], row[1]
def __len__(self):
if self.length == -1:
n_rows = 0
with open(self.data_file, 'r') as f:
reader = csv.reader(f)
for row in reader:
n_rows += 1
self.length = n_rows
return self.length
我也使用yield row[0], row[1]
和return
进行了尝试。都不起作用。
感谢您的帮助
答案 0 :(得分:1)
在使Generator类从keras.utils.Sequence继承方法之前,我一直遇到相同的错误(请参见fit_generator documentation)。您可以尝试以下方法:
import keras
class Generator(keras.utils.Sequence):
def __init__(self, data_file):
self.data_file = data_file
self.length = -1
def __iter__(self):
while True:
with open(self.data_file, 'r') as f:
reader = csv.reader(f)
for row in reader:
yield row[0], row[1]
def __len__(self):
if self.length == -1:
n_rows = 0
with open(self.data_file, 'r') as f:
reader = csv.reader(f)
for row in reader:
n_rows += 1
self.length = n_rows
return self.length