我有一段如下的训练代码-我不明白如何使用数据集将数据作为简单的numpy数组而不是张量来处理数据,因为在下面的代码中出现以下错误:
未启用急切执行时,张量对象不可迭代。 要遍历此张量,请使用tf.map_fn
基本上不能修改我的预处理功能tokenizer.texts_to_sequences
-因此要在批处理期间重用它们。
有人可以帮忙吗?
estimator_model = keras.estimator.model_to_estimator(keras_model=model, model_dir=outputdir)
def batchGen(tuples):
features = tuples[0]
print("All Raw Features -- "+str(features))
labels = tuples[1]
print("All Raw Labels -- "+str(labels))
num = 0
while num<len(features):
yield (features[num], labels[num])
num += 1
def train_fn_custom(features, labels, batch_size):
genSet = lambda:batchGen((features, labels))
dataset = tf.data.Dataset.from_generator(genSet, output_types= tp.float32)
def _preprocess_function(tuples):
features = tuples[0]
labels = tuples[1]
features=tokenizer.texts_to_sequences(features, maxlen)
features=sequence.pad_sequences(features, maxlen=maxlen)
features=features.astype(np.float32)
labels=utils.to_categorical(labels, nb_classes)
labels=labels.astype(np.float32)
return features, labels
dataset = dataset.map(_preprocess_function)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(5) # E P O C H S
return dataset
batch_size = 128
customTrain = lambda: train_fn_custom(x_train, y_train, batch_size)
estimator_model.train(input_fn=customTrain, steps=100)