为什么这个Tensorflow程序有效?

时间:2017-05-02 19:49:58

标签: python-3.x tensorflow

我在Mac OS 10.12.4,Anaconda Python 3.5和Tensorflow 1.1上运行。 我拼凑了下面显示的可重现代码。 我已经定义了" my_model"有参数"功能"和"标签"。 我没有定义它们。 " my_model"调用函数时不带任何参数。 我的Spyder"变量"程序运行后窗口不显示它们。 我的问题是:这些变量在哪里定义?

查尔斯

from sklearn import metrics, cross_validation
from tensorflow.contrib import layers
from tensorflow.contrib import learn
from sklearn.preprocessing import LabelEncoder
import pandas as pd

# shut up the warnings
import warnings
warnings.filterwarnings('ignore')
import logging
logging.getLogger("tensorflow").setLevel(logging.ERROR)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)

def my_model(features, labels):
    labels = tf.one_hot(labels, 3, 1, 0)
    features = layers.stack(features, layers.fully_connected, [10, 20, 10])
    prediction, loss = (learn.models.logistic_regression(features, labels))
    train_op = tf.contrib.layers.optimize_loss(
        loss,
        tf.contrib.framework.get_global_step(),
        optimizer='Adagrad',
        learning_rate=0.1)

    return {'class': tf.argmax(prediction, 1), 'prob': prediction}, loss, train_op

df = pd.read_csv("iris.csv")
df = df.sample(frac=1)  # shuffle all rows
print(df.head())
column_names = list(df.columns[:4])
X = df[column_names].as_matrix()
y = df['Species']
le = LabelEncoder()
le.fit(df["Species"])
y = le.transform(df["Species"])
x_train, x_test, y_train, y_test = cross_validation.train_test_split(
  X, y, test_size=0.2, random_state=35)

classifier = tf.contrib.learn.Estimator(model_fn = my_model)
classifier.fit(x_train, y_train, steps=1000)

y_predicted = [p['class'] for p in classifier.predict(x_test, as_iterable=True)]
score = metrics.accuracy_score(y_test, y_predicted)
print('Accuracy: {0:f}'.format(score))

1 个答案:

答案 0 :(得分:0)

您的代码中未调用

my_model。它是Estimator调用的回调函数,带有2个参数:features和labels。

对于fit()函数,它们实际上是x_trainy_train

正如the doc所说,“模型功能,采取特征和目标张量的张量或指标,并返回预测和损失张量。例如,”(特征,目标) - > (预测,损失)“

你可以看到在source code of Estimator中的第1125行调用了model_fn:

model_fn_results = self._model_fn(features, labels, **kwargs)