在session.run()中评估tf.py_func会返回“无效的类型列表”

时间:2016-06-24 11:48:46

标签: python tensorflow

我的数据按结构排列:

data/
  good/
  bad/

每个子文件夹包含jpg文件。 我正在尝试创建一个输入管道,接受basepath作为输入并构建image_oplabel_op。评估这些会给我一个(图像,标签)元组,例如:

image, label = session.run([image_op, label_op])

要获得给定图像的标签,我必须查看其图像路径。一个简单的解决方案是:

label = int("good" in path)

在tensorflow中我没有支持这样的字符串操作(v0.9),所以我想在上面这个简单的函数中使用tf.py_func包装器。但是,在评估标签op成功的同时,在尝试评估使用图像路径op作为同一session.run()输入的图像路径op和标签op时,我会收到错误。

这是我的python函数和tf图代码:

def get_label(path):
    return int("good" in str(path))

class DataReadingGraph:
    """
    Graph for reading images into a data queue.
    """

    def __init__(self, base_path):
        """
        Construct the queue
        :param base_path: path to base directory that contains good images and bad images
        subdirectories. They in turn can contain further subdirectories.
        :return:
        """

        # tf can't handle recursive files matching (as of version 0.9), so
        # solve that with glob and just pass globbed paths to a constant
        pattern = os.path.join(base_path, "**/*.jpg")
        filenames = tf.constant(glob.glob(pattern, recursive=True))

        filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
        reader = tf.IdentityReader()
        self.key, self.value = reader.read(filename_queue)

        self.label = tf.py_func(get_label, [self.key], [tf.int64])

现在如果我跑

label = session.run(data_reading_graph.label)

一切都很好,我按预期得到了标签。但如果我跑

key, label = session.run([data_reading_graph.key, data_reading_graph.label])

相反,我得到了

<class 'TypeError'>
Fetch argument [<tf.Tensor 'PyFunc:0' shape=<unknown> dtype=int64>] of [<tf.Tensor 'PyFunc:0' shape=<unknown> dtype=int64>] has invalid type <class 'list'>, must be a string or Tensor. (Can not convert a list into a Tensor or Operation.)

我真的不明白这里出了什么问题,虽然它不应该转换成列表,但为什么我不能在同一session.run()中评估关键操作和标签操作。

我可以尝试在开始tf图之前在纯Python代码中进行标签提取,但问题仍然存在 - 为什么我不能在同一py_func

中评估session.run()及其输入

1 个答案:

答案 0 :(得分:1)

`select * from VM_REPORT_TEMP_US where to_date(START_DATE,'YYYY-MM-DD HH24:MI:SS') >= '2016-02-01 00:00:00' AND to_date(END_DATE,'YYYY-MM-DD HH24:MI:SS') <= '2016-`02-24 24:59:00'` 会返回张量的列表

你应该得到tf.py_func

self.label = self.label[0]