I am trying to implement an input pipeline in TensorFlow that keeps the same input batch across several runs, because different parts of the network receive separate weight updates.
My idea was to wrap the input pipeline in a conditional:
# flag to skip image fetch
forwarding_network = tf.placeholder(tf.bool, [], name='forwarding_network')
input_images = None   # image tensor from the input pipeline, must be set for real
input_labels = None   # label tensor from the input pipeline, must be set for real
INPUT_HEIGHT = 64     # height of the images/labels
INPUT_WIDTH = 64      # width of the images/labels

# Fetch new batch from input pipeline
def forwardIR():
    image_batch_fetch, label_batch_fetch = tf.train.batch([input_images, input_labels],
                                                          batch_size=32, capacity=64)
    with tf.variable_scope('im_reader_forward'):
        image_batch = tf.get_variable("image_batch", shape=[32, INPUT_HEIGHT, INPUT_WIDTH, 3],
                                      dtype=tf.float32, trainable=False,
                                      initializer=tf.constant_initializer(0.0))
        image_batch = tf.assign(image_batch, image_batch_fetch)
        label_batch = tf.get_variable("label_batch", shape=[32, INPUT_HEIGHT, INPUT_WIDTH, 1],
                                      dtype=tf.uint8, trainable=False,
                                      initializer=tf.constant_initializer(0.0))
        label_batch = tf.assign(label_batch, label_batch_fetch)
    return image_batch, label_batch

# Hold the last batch, no new fetch from the pipeline
def holdIR():
    with tf.variable_scope('im_reader_forward', reuse=True):
        return tf.get_variable('image_batch', dtype=tf.float32), \
               tf.get_variable('label_batch', dtype=tf.uint8)

# Switch: if forwarding_network == True, fetch new images from the queue; otherwise hold the last batch
image_batch, label_batch = tf.cond(forwarding_network, lambda: forwardIR(), lambda: holdIR())
# calculate loss with batch
net = Model(image_batch)
loss = net.predict()
My problem is that training starts without any errors or failures, but nothing happens. Could it be that there is no connection between the variables and the network ops? The output of the conditional feeds directly into the network model.
Answer 0 (score 0)
Well, it was much easier than I thought. -.- :D
The problem was solved by first running the TF session to evaluate the fetched image/label batch, and then feeding the output into the training iterations through placeholders.
## define input pipeline, network, loss calculation, session, ...
# Evaluate the pipeline once to fetch a single batch ...
image_batch_out, label_batch_out = sess.run([image_batch_ir, label_batch_ir])
feed_dict = {image_batch: image_batch_out, label_batch: label_batch_out}

# ... and feed the same batch into all three weight updates
loss_1, _ = sess.run([loss_val_1, train_op_1], feed_dict=feed_dict)
loss_2, _ = sess.run([loss_val_2, train_op_2], feed_dict=feed_dict)
loss_3, _ = sess.run([loss_val_3, train_op_3], feed_dict=feed_dict)
As mentioned in the comments, no variables are needed at all. :)
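For completeness, here is a minimal, self-contained sketch of this fetch-once-feed-several-times pattern. The toy data, the one-layer model and the two train ops are only stand-ins for the asker's real pipeline, Model and losses, and it assumes TensorFlow 1.x with queue runners:

import numpy as np
import tensorflow as tf  # assumes TF 1.x (queue-based input pipelines)

BATCH, H, W = 4, 8, 8

# Toy data standing in for the real image/label sources (hypothetical).
images = np.random.rand(100, H, W, 3).astype(np.float32)
labels = np.random.randint(0, 2, size=(100, H, W, 1)).astype(np.uint8)

# Queue-based input pipeline: evaluating these tensors dequeues one batch.
image_single, label_single = tf.train.slice_input_producer([images, labels], shuffle=True)
image_batch_ir, label_batch_ir = tf.train.batch([image_single, label_single],
                                                batch_size=BATCH, capacity=2 * BATCH)

# Placeholders the model is built on, so one fetched batch can be fed
# to several train ops without touching the queue again.
image_batch = tf.placeholder(tf.float32, [BATCH, H, W, 3])
label_batch = tf.placeholder(tf.uint8, [BATCH, H, W, 1])

# Minimal stand-in for the asker's Model / losses (hypothetical).
logits = tf.layers.conv2d(image_batch, 1, 3, padding='same')
loss_val = tf.reduce_mean(tf.squared_difference(logits, tf.cast(label_batch, tf.float32)))
train_op_1 = tf.train.GradientDescentOptimizer(0.01).minimize(loss_val)
train_op_2 = tf.train.GradientDescentOptimizer(0.01).minimize(loss_val)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for _ in range(10):
        # Fetch ONE batch from the pipeline ...
        image_out, label_out = sess.run([image_batch_ir, label_batch_ir])
        feed_dict = {image_batch: image_out, label_batch: label_out}
        # ... and reuse it for every weight update of this iteration.
        sess.run(train_op_1, feed_dict=feed_dict)
        sess.run(train_op_2, feed_dict=feed_dict)

    coord.request_stop()
    coord.join(threads)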
Answer 1 (score 0)
To add to the answer: you can also pass the image_batch_ir and label_batch_ir tensors directly to the ops that currently take the placeholders as input.
For example, if your old code was:
image_batch_ir, label_batch_ir = ...
image_batch = tf.placeholder(...)
label_batch = tf.placeholder(...)
loss_val = some_ops(image_batch, label_batch)
image_batch_out, label_batch_out = sess.run([image_batch_ir, label_batch_ir])
feed_dict = { image_batch : image_batch_out, label_batch : label_batch_out }
loss = sess.run([loss_val], feed_dict=feed_dict)
you can instead write:
image_batch_ir, label_batch_ir = ...
loss_val = some_ops(image_batch_ir, label_batch_ir)
loss = sess.run([loss_val])  # no placeholders left, so no feed_dict is needed
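As a concrete illustration, here is a sketch of that direct wiring, reusing the toy pipeline tensors (image_batch_ir, label_batch_ir) and layer stand-ins from the sketch in the first answer; the queue-runner boilerplate is needed either way. Note that once the ops consume the queue tensors directly, every sess.run that touches the loss dequeues a fresh batch, so updates that should share one batch have to be grouped into a single sess.run call.

# Direct wiring: the model consumes the queue output tensors themselves,
# so neither placeholders nor the intermediate sess.run are needed.
logits_direct = tf.layers.conv2d(image_batch_ir, 1, 3, padding='same')
loss_direct = tf.reduce_mean(
    tf.squared_difference(logits_direct, tf.cast(label_batch_ir, tf.float32)))
train_direct = tf.train.GradientDescentOptimizer(0.01).minimize(loss_direct)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Each evaluation of loss_direct/train_direct pulls a new batch from the queue.
    loss_out, _ = sess.run([loss_direct, train_direct])

    coord.request_stop()
    coord.join(threads)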