Tensorflow-使用字符串标签训练神经网络

时间:2019-09-05 18:49:21

标签: python tensorflow neural-network

对于一个大学项目,我必须使用Tensorflow为OCR任务实现神经网络。 训练数据集由两个文件 train-data.csv train-target.csv 组成。 在 train-data 文件中,每行都填充有16x8位图的位,在 train-target 文件中,每行都是字符[az],它是相应字符的标签火车数据中的一行。

我在标签数据集方面遇到了一些问题,我已经按照教程学习了MNIST数据集,但是这里的区别在于我有字符串标签,而不是一键编码的矢量。 在学习完本教程之后,我将尝试使用softmax函数和交叉熵。

# First y * tf.log(y_hat) computes the element-wise multiplication of the two resulting vectors

# Second, tf.reduce_sum( , reduction_indices=[1]) computes the sum along the second dimension (the first one are the examples)
# Finally, tf.reduce_mean() computes the mean over the first dimension, i.e. the examples
cross_entropy = tf.reduce_mean(-tf.reduce_sum(tf.strings.to_number(y) * tf.math.log(y_hat), reduction_indices=[1]))

train_step = tf.compat.v1.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

在上面的几行中,我使用了tf.strings.to_number(y)将char转换为数字值。

由于运行()方法不接受张量对象,因此在运行会话时此转换会引起问题。

for _ in range(1000):
    batch_xs, batch_ys = next_batch(100, raw_train_data, train_targets)
    sess.run(train_step, feed_dict={x: batch_xs, y: tf.strings.to_number(batch_ys.reshape((100,1)))})

如果我不将char转换为数值,则会出现此错误:

InvalidArgumentError: StringToNumberOp could not correctly convert string: e
 [[{{node StringToNumber}}]]

我试图弄清楚如何解决此问题或如何使用字符标签训练神经网络,这是我整天都在努力解决的问题。 有人知道如何解决这个问题吗?

1 个答案:

答案 0 :(得分:0)

最后我找到了错误。 因为我对机器学习还很陌生,所以我忘记了很多算法不能处理分类数据集。

解决方案是对目标标签执行一次热编码,并使用以下功能将此新数组提供给newtork:

desc_detail