TypeError:attr'Tlabels'的DataType字符串不在允许值列表中:int32,int64

时间:2017-02-28 08:39:08

标签: python machine-learning google-cloud-platform

我需要帮助训练张量流模型。我刚开始使用Google Cloud Platform,并尝试使用自己的数据集训练自己的模型。

以下是我目前的代码

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
import h5py

# Datasets
training_set1 = h5py.File('/Users/Fang/workspace/keras_project/for_google-ml/trainer/11_type.h5', 'r')
output_set1 = open("/Users/Fang/workspace/keras_project/for_google-ml/trainer/output_type.txt", "r")

X = training_set1['the_data'][:]
Y = output_set1.read().split(',')

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=2)]

classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                        hidden_units=[500, 100, 11],
                                        n_classes=3,
                                        model_dir="/tmp/the_model")

classifier.fit(X, Y, steps=2000)
accuracy_score = classifier.evaluate(X, Y)["accuracy"]
print("Accuracy: %.2f".format(accuracy_score))

在本地运行时我得到了这样的错误。

TypeError: DataType string for attr 'Tlabels' not in list of allowed values: int32, int64

我的training_set包含以hdf5格式保存为the_data数据集的以下数据。它有多个500个数字的数组。示例:

[[12.424, 384.742,...],
 [3492.293, 349,..,...],
 [...,...,...],
  ...
  ...
 [...]]

我的output_set是一个包含以下500个数据的文本文件

aaa,aaa,aaa,...,bbb,bbb,bbb,...ccc,ccc,...,kkk,kkk,kkk

感谢您的帮助。

0 个答案:

没有答案