我刚刚使用pip安装了张量流,我尝试运行以下教程:
https://www.tensorflow.org/versions/r0.11/tutorials/tflearn/index.html
# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"
# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING,
target_dtype=np.int)
test_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TEST,
target_dtype=np.int)
但我有错误:
training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING,
AttributeError: 'module' object has no attribute 'load_csv'
我读了一些答案,说我需要使用pandas数据帧?但是,不应该像教程一样工作吗?那太奇怪了!我不应该是唯一面对这个问题的人吗?
这里是整个代码,如教程:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"
# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TRAINING,
target_dtype=np.int)
test_set = tf.contrib.learn.datasets.base.load_csv(filename=IRIS_TEST,
target_dtype=np.int)
# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=3,
model_dir="/tmp/iris_model")
# Fit model.
classifier.fit(x=training_set.data,
y=training_set.target,
steps=2000)
# Evaluate accuracy.
accuracy_score = classifier.evaluate(x=test_set.data,
y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))
# Classify two new flower samples.
new_samples = np.array(
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = classifier.predict(new_samples)
print('Predictions: {}'.format(str(y)))
答案 0 :(得分:1)
在TensorFlow版本0.11中删除了函数tf.contrib.learn.datasets.base.load_csv()
。根据文件是否有标题(并且Iris数据集确实有标题),替换函数是:
答案 1 :(得分:0)
因为我的版本是11 ...他们删除了11中的load_csv而没有更改教程...我必须运行版本0.10.0rc0才能运行教程。
答案 2 :(得分:0)
库urllib已经删除了方法urlopen(),你应该导入urllib.request,然后使用方法urllib.request.urlopen()。