我有爱尔兰数据集的基本分类代码。
import tensorflow as tf
import pandas as pd
COLUMN_NAMES = [
'SepalLength',
'SepalWidth',
'PetalLength',
'PetalWidth',
'Species'
]
# Import training dataset
training_dataset = pd.read_csv('iris_training.csv', names=COLUMN_NAMES, header=0)
train_x = training_dataset.iloc[:, 0:4]
train_y = training_dataset.iloc[:, 4]
# Import testing dataset
test_dataset = pd.read_csv('iris_test.csv', names=COLUMN_NAMES, header=0)
test_x = test_dataset.iloc[:, 0:4]
test_y = test_dataset.iloc[:, 4]
columns_feat = [
tf.feature_column.numeric_column(key='SepalLength'),
tf.feature_column.numeric_column(key='SepalWidth'),
tf.feature_column.numeric_column(key='PetalLength'),
tf.feature_column.numeric_column(key='PetalWidth')
]
classifier = tf.estimator.DNNClassifier(
feature_columns=columns_feat,
# Two hidden layers of 10 nodes each.
hidden_units=[10, 10],
# The model is classifying 3 classes
n_classes=3)
def train_function(inputs, outputs, batch_size):
dataset = tf.data.Dataset.from_tensor_slices((dict(inputs), outputs))
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
# Train the Model.
classifier.train(
input_fn=lambda:train_function(train_x, train_y, 100),
steps=1000)
def evaluation_function(attributes, classes, batch_size):
attributes=dict(attributes)
if classes is None:
inputs = attributes
else:
inputs = (attributes, classes)
dataset = tf.data.Dataset.from_tensor_slices(inputs)
assert batch_size is not None, "batch_size must not be None"
dataset = dataset.batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
# Evaluate the model.
eval_result = classifier.evaluate(
input_fn=lambda:evaluation_function(test_x, test_y, 100))
我评估结果,但我如何对我的数据进行预测,因为现在我只获得了丢失和时期的控制台信息,准确性。例如,如果我有物种以外的一切。我想给自己的萼片长度等等所以我可以得到物种的预测,这将是另一个变量。我是否必须创建像pred_x或pred_y(pandas dataframe)这样的变量然后将它们放入eval_result?
答案 0 :(得分:1)
这是你的意思吗?例如:new_samples = np.array([[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=np.float32)
如果您希望这样的新数据进行预测,那么您可以参考此代码。TensorFlow-Iris-Classification
答案 1 :(得分:0)
与所有估算工具类一样,DNNClassifier
类具有predict
方法,可以进行实际预测。文档为here。