我想培训一种特定的网络架构(GAN),在培训期间需要来自不同来源的输入。
一个输入源是从磁盘加载的示例。另一个来源是生成子网络创建示例。
要选择要向网络提供哪种输入,请使用tf.cond
。有一点需要注意,has already been explained:tf.cond
评估两个条件分支的输入,即使最终只使用其中一个。
足够的设置,这是一个最小的工作示例:
import numpy as np
import tensorflow as tf
BATCH_SIZE = 32
def load_input_data():
# Normally this data would be read from disk
data = tf.reshape(np.arange(10 * BATCH_SIZE, dtype=np.float32), shape=(10 * BATCH_SIZE, 1))
return tf.train.batch([data], BATCH_SIZE, enqueue_many=True)
def generate_input_data():
# Normally this data would be generated by a much bigger sub-network
return tf.random_uniform(shape=[BATCH_SIZE, 1])
def main():
# A bool to choose between loaded or generated inputs
load_inputs_pred = tf.placeholder(dtype=tf.bool, shape=[])
# Variant 1: Call "load_input_data" inside tf.cond
data_batch = tf.cond(load_inputs_pred, load_input_data, generate_input_data)
# Variant 2: Call "load_input_data" outside tf.cond
#loaded_data = load_input_data()
#data_batch = tf.cond(load_inputs_pred, lambda: loaded_data, generate_input_data)
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
print(threads)
# Get generated input data
data_batch_values = sess.run(data_batch, feed_dict={load_inputs_pred: False})
print(data_batch_values)
# Get input data loaded from disk
data_batch_values = sess.run(data_batch, feed_dict={load_inputs_pred: True})
print(data_batch_values)
if __name__ == '__main__':
main()
Variant 1根本不起作用,因为队列运行程序线程似乎没有运行。 print(threads)
输出类似[<Thread(Thread-1, stopped daemon 140165838264064)>, ...]
的内容。
变体2可以正常工作,print(threads)
输出类似[<Thread(Thread-1, started daemon 140361854863104)>, ...]
的内容。但由于load_input_data()
已在tf.cond
之外调用,因此即使load_inputs_pred
为False
,也会从磁盘加载批量数据。
是否可以使Variant 1工作,以便输入数据仅在load_inputs_pred
为True
时加载,而不是每次调用session.run()
时都会加载?
答案 0 :(得分:0)
如果您在加载数据时使用队列并使用批量输入进行跟踪,那么这不应该是一个问题,因为您可以指定已加载或存储在队列中的最大数量。
input = tf.WholeFileReader(somefilelist) # or another way to load data
return tf.train.batch(input,batch_size=10,capacity=100)
请点击此处了解更多详情: https://www.tensorflow.org/versions/r0.10/api_docs/python/io_ops.html#batch
还有一种替代方法可以完全跳过tf.cond。只需定义两个通过自动编码器和描述符跟踪数据的丢失,另一个通过鉴别器跟踪数据。
然后它就变成了调用
的问题sess.run(auto_loss,feed_dict)
或
sess.run(real_img_loss,feed_dict)
通过这种方式,图表只会在遇到损失的情况下运行。如果需要更多解释,请告诉我。
最后,如果您使用预加载的数据,我认为要做一个变体工作,你需要做这样的事情。
https://www.tensorflow.org/versions/r0.10/how_tos/reading_data/index.html#preloaded-data
否则,我不确定问题是诚实的。