我在TF估算器上遇到了一个奇怪的问题,试图在我的输入函数中使用tf.Dataset。
首先,我的模型如下:
model = tf.estimator.DNNClassifier(
feature_columns=my_feature_column,
hidden_units=[hidden_layers, hidden_layers],
n_classes=n_classes)
我的特色专栏就是这样
my_feature_column = [tf.feature_column.numeric_column(key='image', shape=[32, 32, 3])]
现在,如果我像这样进行训练,那么一切都很好,并且训练只需几秒钟即可完成:
model.train(
input_fn=tf.estimator.inputs.numpy_input_fn(
dict({'image':X_train}),
y_train,
shuffle=True),
steps=nb_epoch)
但是当我尝试在输入函数中添加tf.Datasets时,它将永远需要运行:
def input_fn(features, labels, batch_size):
dataset = tf.data.Dataset.from_tensor_slices(({'image':features}, labels))
return dataset.shuffle(1000).batch(batch_size).repeat()
model.train(
input_fn=lambda:input_fn(X_train, y_train, batch_size),
steps=nb_epoch)
任何人都可以看到我在做什么错吗?应该完全一样吧?
谢谢, 保罗
答案 0 :(得分:0)
您的数据集将无限重复,并且没有默认的最大迭代次数,因此tensorflow不知道何时停止。
将return dataset.shuffle(1000).batch(batch_size).repeat()
所在的行替换为return dataset.shuffle(1000).batch(batch_size).repeat(10)
之类的东西,它将训练10个纪元,你会没事的。