attr'TI'的DataType字符串不在允许值列表中:uint8,int32,int64

时间:2017-02-22 13:22:12

标签: tensorflow neural-network

我一直在使用CNN进行文本分类,并使用了tensorflow的contrib learn。

但是,当我尝试执行以下代码时:

classifier = learn.Estimator(model_fn=cnn_model)

classifier.fit(x_train, y_train, steps=10000) 
y_predicted = [ p['class'] for p in classifier.predict(x_test, as_iterable=True)] 

score = metrics.accuracy_score(y_test, y_predicted) 

print('Accuracy: {0:f}'.format(score))

我正在运行以下错误:

  

错误:attr'TI'的数据类型字符串不在允许值列表中:   uint8,int32,int64'classifier.fit'

1 个答案:

答案 0 :(得分:0)

您需要将输入y_train转换为给定类型。 print(type(y_train))最有可能是浮点而不是整数。