使用tf.cond()来提供我的图表以进行培训和验证

时间:2017-09-25 14:23:55

标签: tensorflow tensorflow-gpu

在我的TensorFlow代码中,我希望我的网络从两个9999999999对象中的一个获取输入,具体取决于我是否要进行培训或测试。 我写的图形构造代码的一部分如下:

StagingArea

我使用with tf.device("/gpu:0"): for i in range(numgpus): with tf.variable_scope(tf.get_variable_scope(), reuse=i>0) as vscope: with tf.device('/gpu:{}'.format(i)): with tf.name_scope('GPU-Tower-{}'.format(i)) as scope: phase = tf.get_variable("phase", [], initializer=tf.zeros_initializer(),dtype=tf.uint8, trainable=False) phaseassigntest = phase.assign(1) phaseassigntrain = phase.assign(0) phasetest = tf.equal(phase, 0) is_training = tf.cond(phasetest, lambda: tf.constant(True), lambda: tf.constant(False)) trainstagingarea = tf.contrib.staging.StagingArea([tf.float32, tf.int32], shapes=[[trainbatchsize, 3, 221, 221], [trainbatchsize]], capacity=20) putoptrain = trainstagingarea.put(train_iterator.get_next()) trainputop.append(putoptrain) getoptrain = trainstagingarea.get() traingetop.append(getoptrain) trainclearop = trainstagingarea.clear() trainstageclear.append(trainclearop) trainsizeop = trainstagingarea.size() trainstagesize.append(trainsizeop) valstagingarea = tf.contrib.staging.StagingArea([tf.float32, tf.int32], shapes=[[valbatchsize, 3, 221, 221], [valbatchsize]], capacity=20) putopval = valstagingarea.put(val_iterator.get_next()) valputop.append(putopval) getopval = valstagingarea.get() valgetop.append(getopval) valclearop = valstagingarea.clear() valstageclear.append(valclearop) valsizeop = valstagingarea.size() valstagesize.append(valsizeop) #elem = valgetop[i] elem = tf.cond(is_training,lambda: traingetop[i],lambda: valgetop[i]) img = elem[0] label = elem[1] labelonehot = tf.one_hot(label, depth=numclasses) net, networksummaries = overfeataccurate(img,numclasses=numclasses, phase=is_training) 来确保网络由两个tf.cond对象中的一个提供。一个用于培训,另一个用于验证。 现在,当我尝试按如下方式执行图形时,我没有得到任何结果,事实上代码只是挂起而且我必须终止这个过程。

StagingArea

而不是with tf.Session(graph=g,config=config) as sess: sess.run(init_op) sess.run(tf.local_variables_initializer()) sess.run(val_initialize) for i in range(20): sess.run(valputop) print(sess.run(valstagesize)) writer = tf.summary.FileWriter('.', graph=tf.get_default_graph()) epoch = 0 iter = 0 print("Performing Validation") sess.run(phaseassigntest) saver = tf.train.Saver() while(epoch<10): time_init = time.time() while True: try: [val_accu, _, summaries] = sess.run([towervalidation, towervalidationupdateop,validation_summary_op]) print(val_accu) 我直接指定tf.cond(),代码工作得很好。 我在这里错过了一些东西吗?

根据我是否要进行培训或测试来提供网络的正确方法是什么?

注意即使我将elem = valgetop[i]设置为1,错误也不会消失。

1 个答案:

答案 0 :(得分:2)

您的问题

您认为tf.cond做什么

根据该标志,执行将traingetop [i]或valgetop [i]放入elem张量所需的内容。

tf.cond实际做了什么

执行获取两者 traingetop [i]和valgetop [i]所需的内容,然后将其中一个传递到elem张量。

因此

它永远悬挂的原因是因为它正在等待将一个元素添加到您的训练临时区域(以便它可以获取该元素并将其丢弃)。你没有意识到这就是它正在做的事情,你原谅了;它实际上非常违反直觉。文档非常不清楚如何处理这个问题。

推荐的解决方案(通过Tensorflow文档)

如果您真的需要队列在同一个图表中,那么您需要制作两份ENTIRE图表副本,一份由您的训练临时区域提供,另一份由您的验证暂存区域提供。然后,您只需在sess.run电话中使用相关张量即可。我建议创建一个获取队列输出张量的函数,并返回model_output张量。现在您有train_time_output张量和validation_time_output张量,您可以在sess.run中选择要执行的张量。

警告

您需要确保实际上没有创建新变量以配合这些新操作。要做 ,请查看有关variables的最新文档。看起来他们已经从v0.12简化了它,它基本上归结为使用tf.get_variable而不是tf.Variable来创建变量。

我喜欢的工作

虽然这是推荐的解决方案(AFAIK),但对我来说却非常不满意;你正在图上创建一组其他操作,恰好使用相同的权重。通过滥用列车时间和测试/验证时间之间的分离似乎存在很多程序员错误的可能性(导致模型在这些时间意外地起作用)。更差;它没有解决tf.cond要求两个分支的输入值的问题,它只是强迫你复制整个图形,这并不总是可行。

我更喜欢在图表中没有这样的队列,并将模型视为一个函数,可以提供一个示例,而无需关心它的来源。也就是说,我将使用tf.placeholder作为输入来实例化模型,并且在执行时我将使用feed_dict来实际提供值。它会起到像这样的功能

#inside main training loop
if time_to_train:
    example = sess.run(traingettop)
else:
    example = sess.run(valgettop)
result = sess.run(model_output, {input_placeholder: example})

注意您可以使用feed_dict为模型中任何位置的任何张量提供任何值,这非常有用。因此,您可以更改由于tf.cond始终需要输入的任何模型定义,例如:

a = tf.constant(some_value)
b = tf.placeholder(tf.float32)
flag = tf.placeholder(tf.bool, [])
one_of_them = tf.cond(flag, a, b)
model_output = build_graph(one_of_them)

进入一个没有的定义,如:

a = tf.constant(some_value)
model_output = build_graph(a)

请记住,您始终可以覆盖执行时a的内容:

# In main training loop,
sess.run(train_op, {a: some_other_value})

这实质上将条件推送到本地python土地。在您的代码中,您最终可能会遇到以下内容:

if condition_satisfied:
    sess.run(train_op, {a:some_other_value})
else:
    sess.run(train_op)

性能问题

如果你在一台机器上使用tensorflow,那么这个解决方案的实际上没有性能成本,因为放入example python变量的numpy数组实际上仍然是存储在GPU上。

如果您以分布式方式使用tensorflow,那么此解决方案会破坏您的性能;它需要将示例从它所在的任何机器发送到主设备,以便它可以将其发回。