我可以从tf.train.string_input_producer()获取文件的后缀吗?

时间:2016-10-31 08:46:00

标签: tensorflow

没有一般功能可以让您自动识别TensorFlow中的图像是jpeg还是png。如果收到无效输入,代码将会中断。

我想填充一个包含大量文件名的string_input_producer(包括jpeg和png),然后在决定将其传递给decode_jpeg或decode_png之前评估后缀。

有人可以提供一种方法来做到这一点,而无需进行任何预处理吗?

编辑 @Allen

代码来说明我在做什么。

def inputs():
    filenames = get_filenames() # crawls directories for all jpeg and png files.
    filename_queue = tf.train.string_input_producer(filenames)
    image = read_image(filename_queue) # this function has to split between decode_jpeg and decode_png
    image = preprocess(image)
    ...shuffle_batch stuff...
    return batch

def train():
    input = inputs()
    predictions = inference(input)
    ...loss definition and standard stuff...
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    sess.run([train_op])

这是我打算做的。我不认为可以做一个评估。

1 个答案:

答案 0 :(得分:1)

tf.cond看起来就像你正在寻找的那样(只需确保图像处理操作在fn1和fn2中定义,这样你就可以获得真正的条件执行),并结合tf.decode_raw来读取文件名的最后几个字节:

tf.decode_raw(string, tf.uint8)

结果是包含字符串中字节的整数向量,可以使用TensorFlow操作对其进行切片和比较。例如,要检查字符串是否以“.jpeg”结尾:

import tensorflow as tf

def is_jpeg(file_name_string):
    file_name_bytes = tf.decode_raw(file_name_string, tf.uint8)
    return tf.reduce_all(tf.equal(file_name_bytes[-5:],
                                  tf.decode_raw(".jpeg", tf.uint8)))

with tf.Session():
    print(is_jpeg(tf.convert_to_tensor("file1.png")).eval()) # false
    print(is_jpeg(tf.convert_to_tensor("file2.jpeg")).eval()) # true

要完成图像解码,请将生成的布尔Tensor作为谓词传递给cond():

decoded_image = tf.cond(is_jpeg(file_name),
                        lambda: read_and_decode_jpeg(file_name),
                        lambda: read_and_decode_png(file_name))