我正在使用Tensorflow,训练了一个广泛的网络,并希望预测一些值。我使用了像Tensorflow iris prediction example之类的网,但改变了预测部分
new_samples = np.array([[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = list(classifier.predict(new_samples, as_iterable=True))
可以在我自己的输入函数中从我的测试文件中读取数据:
y = list(classifier.predict(input_fn=lambda: input_fn(test_file_name, batch_size, batch_number)))
经过一些测试后,我发现预测顺序不是文件的数据顺序。如何强制Tensorflow在正确的修正中输出预测?作为另一种选择,如何使用功能(以及行的标签)打印出预测?
感谢您的支持。
答案 0 :(得分:1)
8个月之后回答这个问题,但万一其他人偶然发现并且有同样的问题 - 我怀疑问题是你使用了像
这样的输入功能def get_input_fn(data_set, num_epochs=None, shuffle=True):
return tf.estimator.inputs.pandas_input_fn(
x=pd.DataFrame(data_set[FEATURES]),
y=pd.Series(data_set[LABELS]),
num_epochs=num_epochs,
shuffle=shuffle, num_threads=1)
哪个好,但是当你运行predict()时你需要设置shuffle = False(否则它会改变你的输出!)