在Tensorflow中阅读批量图像和培训

时间:2018-04-19 12:13:18

标签: tensorflow

我试图在这里批量阅读图像。并训练他们。当我执行这个 它似乎陷入困境。没有进展。 谁能发现问题?没有写入日志。

def train():
filenames = tf.train.string_input_producer(
    tf.train.match_filenames_once("D:/*.png"))
reader = tf.WholeFileReader()
_, input = reader.read(filenames)
input = tf.Print(input,[input,tf.shape(input),"Input shape"])
input_image = tf.image.decode_png(input, channels=3)
input_image.set_shape([299, 299, 3])

batch = tf.train.batch([input_image],
                       batch_size=5,
                       allow_smaller_final_batch=True,
                       shapes=None,
                       num_threads = 1,
                       capacity = 32,
                       enqueue_many = False,
                       dynamic_pad = False)

init = (tf.global_variables_initializer(), tf.local_variables_initializer())

with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    train_writer = tf.summary.FileWriter('D:/TensorFlow/logs/1/train', sess.graph)
    tf.print(input_image)

    for it in range(2):
        merge = tf.summary.merge_all()
        summary,_, X_batch =  sess.run([merge,input_image,batch])
        writer = train_writer.add_summary(summary)
        _, DiscriminatorLoss = sess.run([D_optimizer, Disc_loss], feed_dict={X: X_batch, Z: samplefromuniformdistribution(5, 100)})
        print (DiscriminatorLoss)
        _, GeneratorLoss = sess.run([G_optimizer, Generate_loss], feed_dict={Z: samplefromuniformdistribution(5, 100)})

    writer.flush()
    writer.close()

coord.request_stop()
coord.join(threads)
sess.close()

1 个答案:

答案 0 :(得分:0)

在我收到其他答案之前,我正在考虑这个答案。 我发现的问题与

有关
    filenames = tf.train.string_input_producer(
    tf.train.match_filenames_once("D:/Development_Avecto/TensorFlow/resizedimages/*.png"))

有问题的代码有这样的模式。

   filenames = tf.train.string_input_producer(
   tf.train.match_filenames_once("D:/*.png"))

D:/有很多子文件夹。代码是否试图搜索所有子文件夹?这需要很长时间才能完成。 检查了源代码(https://github.com/tensorflow/tensorflow/blob/r1.7/tensorflow/python/ops/io_ops.py),但没有识别读取操作。

即使对我来说不明显,也会搜索子文件夹。测试代码基于此我得出了这个结论。

def readtest():
filenames = tf.train.string_input_producer(
    tf.train.match_filenames_once("D:/FolderWithSubFolders/*.png"))
reader = tf.WholeFileReader()
key, input = reader.read(filenames)

init = (tf.global_variables_initializer(), tf.local_variables_initializer())

with tf.Session() as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    file, document_tensor = sess.run([key,input])
    print (file)
coord.request_stop()
coord.join(threads)
sess.close()