张量流中的输入管道

时间:2017-01-14 14:25:42

标签: python input tensorflow jpeg

在初学者的官方tensorflow mnist数据集教程中练习时,我正在尝试将mnist数据更改为我自己从搜索引擎收集的图像。

strFilePaths,iLabels ,strSubFolderNames,iNumTotalDatasets = ScanForImage('Datasets')

tsFileNameQueue = tf.train.string_input_producer(strFilePaths)
tsReader = tf.WholeFileReader()
_,tsImage = tsReader.read(tsFileNameQueue)

tsImage = tf.image.decode_jpeg(tsImage, channels=3)
tsImage = tf.cast(tsImage,tf.float32)
tsLabels = tf.convert_to_tensor(iLabels, dtype=tf.float32)
tsImage = tf.reshape(tsImage, shape=[1,168*300*3])

matWeights = tf.Variable(tf.random_normal([168*300*3, 2]))
vBiases = tf.Variable(tf.zeros([2]))
vPredictions = tf.nn.softmax(tf.matmul(tsImage, matWeights) + vBiases)
fCrossEntropy = tf.reduce_mean(-tf.reduce_sum(tsLabels * tf.log(vPredictions), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(fCrossEntropy)
init = tf.global_variables_initializer()

with tf.Session() as sess : 
    sess.run(init)
    for i in range (1000) : 
    tsTrainingSets = tf.train.batch([tsImage,tsLabels], batch_size=100)
    sess.run(train_step)
        if i % 20 == 0 : 
            correct_prediction = tf.equal(tf.argmax(vPredictions,1),tf.argmax(tsTrainingSets[1],1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
            print(sess.run(accuracy))

strFilePaths是包含我所有图像路径的标准python列表,iLabels是表示标签的列表列表。在这种情况下我只有两个班级。

这个程序在没有错误输出的情况下运行,但是tensorflow只是继续运行而不给我任何输出。我已经在tensorflow网站上阅读过几千次的“阅读文件”会议,但我仍然不清楚我是否做得对不对。

Q1:这段代码出了什么问题? Q2:有没有关于如何将jpeg文件读入tensorflow并对它们执行一些训练任务的完整示例?

1 个答案:

答案 0 :(得分:0)

不幸的是,如果无法访问您的代码和文件,我无法帮助您进一步调试。但是,您可以在image_retraining示例中查看如何重新启动Inception以识别新类别信息(例如鲜花)的完整示例:https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/examples/image_retraining/retrain.py