试图让这个tensorflow脚本工作

时间:2018-02-19 22:49:43

标签: python tensorflow

我试图将我的scikit-learn python脚本移动到tensorflow代码中。继续陷入错误。请帮忙!

    import pandas as pd
    import numpy as np
    import tensorflow as tf

    # read csv
    df = pd.read_csv("/Downloads/iris-2.csv", header=0)

    # get header names as array
    features = list(df.columns.values)
    label = features.pop()
    classes = len(df[label].unique())

    # encode target
    X = df[features]
    y = df[label]

    # convert feature headers into tf
    for index,value in enumerate(features):
        features[index] = tf.feature_column.numeric_column(value)

    # initialize classifier
    classifier = tf.estimator.DNNClassifier(
        feature_columns=features,
        hidden_units=[10, 10],
        n_classes=classes)

    # train the classifier
    dataset = tf.data.Dataset.from_tensor_slices((dict(X), y))
    dataset = dataset.shuffle(1000).repeat().batch(0)
    data = dataset.make_one_shot_iterator().get_next()
    classifier.train(input_fn=lambda:data,steps=3)
    predictions = classifier.predict([5.1,3.0,4.2,1.2])
    print(predictions)

我坚持的最新错误是:

ValueError: Passed Tensor("dnn/head/weighted_loss/Sum:0", shape=(), dtype=float32) should have graph attribute that is equal to current graph <tensorflow.python.framework.ops.Graph object at 0x10dd9a190>.

这是我使用的数据集:https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/d546eaee765268bf2f487608c537c05e22e4b221/iris.csv

1 个答案:

答案 0 :(得分:0)

无法预先计算输入张量(变量数据数据集)。它们需要在 train 调用中传递给 input_fn 的函数内计算,以便张量在Estimator(分类器的图形中)在调用 train()期间创建。所以对于你的最后一个块你可以使用:

# train the classifier
def my_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((dict(X), y))
    dataset = dataset.shuffle(1000).repeat().batch(0)
    return dataset.make_one_shot_iterator().get_next()
classifier.train(input_fn=my_input_fn, steps=3)
predictions = classifier.predict([5.1,3.0,4.2,1.2])
print(predictions)