我一直在使用CNN进行文本分类,并使用了tensorflow的contrib learn。
但是,当我尝试执行以下代码时:
classifier = learn.Estimator(model_fn=cnn_model)
classifier.fit(x_train, y_train, steps=10000)
y_predicted = [ p['class'] for p in classifier.predict(x_test, as_iterable=True)]
score = metrics.accuracy_score(y_test, y_predicted)
print('Accuracy: {0:f}'.format(score))
我正在运行以下错误:
错误:attr'TI'的数据类型字符串不在允许值列表中: uint8,int32,int64'classifier.fit'
答案 0 :(得分:0)
您需要将输入y_train
转换为给定类型。
print(type(y_train))
最有可能是浮点而不是整数。