如何将tf.cond与批处理操作/队列运行程序结合使用

时间:2016-09-16 12:19:51

标签: python tensorflow

情况

我想培训一种特定的网络架构(GAN),在培训期间需要来自不同来源的输入。

一个输入源是从磁盘加载的示例。另一个来源是生成子网络创建示例。

要选择要向网络提供哪种输入,请使用tf.cond。有一点需要注意,has already been explainedtf.cond评估两个条件分支的输入,即使最终只使用其中一个。

足够的设置,这是一个最小的工作示例:

import numpy as np
import tensorflow as tf

BATCH_SIZE = 32

def load_input_data():
  # Normally this data would be read from disk
  data = tf.reshape(np.arange(10 * BATCH_SIZE, dtype=np.float32), shape=(10 * BATCH_SIZE, 1))
  return tf.train.batch([data], BATCH_SIZE, enqueue_many=True)

def generate_input_data():
  # Normally this data would be generated by a much bigger sub-network
  return tf.random_uniform(shape=[BATCH_SIZE, 1])

def main():
  # A bool to choose between loaded or generated inputs
  load_inputs_pred = tf.placeholder(dtype=tf.bool, shape=[])

  # Variant 1: Call "load_input_data" inside tf.cond
  data_batch = tf.cond(load_inputs_pred, load_input_data, generate_input_data)
  # Variant 2: Call "load_input_data" outside tf.cond
  #loaded_data = load_input_data()
  #data_batch = tf.cond(load_inputs_pred, lambda: loaded_data, generate_input_data)

  init_op = tf.initialize_all_variables()

  with tf.Session() as sess:
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    print(threads)

    # Get generated input data
    data_batch_values = sess.run(data_batch, feed_dict={load_inputs_pred: False})
    print(data_batch_values)

    # Get input data loaded from disk
    data_batch_values = sess.run(data_batch, feed_dict={load_inputs_pred: True})
    print(data_batch_values)

if __name__ == '__main__':
  main()

问题

Variant 1根本不起作用,因为队列运行程序线程似乎没有运行。 print(threads)输出类似[<Thread(Thread-1, stopped daemon 140165838264064)>, ...]的内容。

变体2可以正常工作,print(threads)输出类似[<Thread(Thread-1, started daemon 140361854863104)>, ...]的内容。但由于load_input_data()已在tf.cond之外调用,因此即使load_inputs_predFalse,也会从磁盘加载批量数据。

是否可以使Variant 1工作,以便输入数据仅在load_inputs_predTrue时加载,而不是每次调用session.run()时都会加载?

1 个答案:

答案 0 :(得分:0)

如果您在加载数据时使用队列并使用批量输入进行跟踪,那么这不应该是一个问题,因为您可以指定已加载或存储在队列中的最大数量。

input = tf.WholeFileReader(somefilelist) # or another way to load data
return tf.train.batch(input,batch_size=10,capacity=100) 

请点击此处了解更多详情: https://www.tensorflow.org/versions/r0.10/api_docs/python/io_ops.html#batch

还有一种替代方法可以完全跳过tf.cond。只需定义两个通过自动编码器和描述符跟踪数据的丢失,另一个通过鉴别器跟踪数据。

然后它就变成了调用

的问题
sess.run(auto_loss,feed_dict)

sess.run(real_img_loss,feed_dict)

通过这种方式,图表只会在遇到损失的情况下运行。如果需要更多解释,请告诉我。

最后,如果您使用预加载的数据,我认为要做一个变体工作,你需要做这样的事情。

https://www.tensorflow.org/versions/r0.10/how_tos/reading_data/index.html#preloaded-data

否则,我不确定问题是诚实的。