How to train with while_loop and tf.layers.batch_normalization?

Posted: 2018-04-23 12:21:36

Tags: python while-loop batch-normalization

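I need to add a batch normalization layer inside the body of a while loop, but training the network crashes. If I remove x = tf.layers.batch_normalization(x, training=flag), everything works. Can I use the high-level API inside the loop body? I don't want to fall back to tf.nn.batch_normalization, because this is a simplified example and my real network is much more complex.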

import tensorflow as tf
from data_pre import get_data

data, labels = get_data(
    ['../UCR_TS_Archive_2015/ItalyPowerDemand/ItalyPowerDemand_TRAIN'], 24, 2,True, 0, 2)  #pylint: disable=line-too-long

flag = True

def cond(i, x):
    return i < 1

def body(i, x):
    x = tf.layers.conv1d(x, 1, 7, padding='same')
    x = tf.layers.batch_normalization(x, training=flag)
    x = tf.nn.relu(x)
    return i + 1, x

_, y = tf.while_loop(cond, body, [0, data], back_prop=False)

y = tf.layers.flatten(y)
logits = tf.layers.dense(y, 2)

loss = tf.losses.mean_squared_error(labels, logits)
optimizer = tf.train.AdamOptimizer()
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss, tf.train.get_global_step())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for _ in range(10):
        sess.run(train_op)
    coord.request_stop()
    coord.join(threads)

Here is the error message:

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1327, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1312, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1420, in _call_tf_sessionrun
    status, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 516, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'gradients/mean_squared_error/div_grad/Neg' has inputs from different frames. The input 'while/batch_normalization/AssignMovingAvg_1' is in frame 'while/while_context'. The input 'one_hot' is in frame ''.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./test.py", line 40, in <module>
    sess.run(train_op)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 905, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1140, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1321, in _do_run
    run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1340, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: The node 'gradients/mean_squared_error/div_grad/Neg' has inputs from different frames. The input 'while/batch_normalization/AssignMovingAvg_1' is in frame 'while/while_context'. The input 'one_hot' is in frame ''. 
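The node named in the error, while/batch_normalization/AssignMovingAvg_1, is one of the ops that update the layer's moving statistics in training mode. Nothing in the forward pass consumes their output, which is why tf.layers.batch_normalization puts them in the UPDATE_OPS collection and the code above runs them through tf.control_dependencies; because they were created inside the loop body, they belong to the loop's frame, while the gradient ops built by minimize() do not. A minimal pure-Python sketch of the update rule these ops apply, assuming the default momentum=0.99 of tf.layers.batch_normalization (the helper names are hypothetical):

```python
def batch_mean(values):
    # Mean of one mini-batch (a stand-in for the statistic batch norm computes in training mode).
    return sum(values) / len(values)


def assign_moving_avg(moving, batch_value, momentum=0.99):
    # The moving-average update: moving -= (moving - batch_value) * (1 - momentum).
    return moving - (moving - batch_value) * (1.0 - momentum)


moving_mean = 0.0  # the moving mean starts at zero by default
for step in range(3):
    moving_mean = assign_moving_avg(moving_mean, batch_mean([1.0, 2.0, 3.0, 4.0]))
    print(step, moving_mean)
```

The variance is tracked with the same rule. Because the update is a side effect, TensorFlow only runs it when something depends on it, hence the UPDATE_OPS dance.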

1 Answer:

Answer 0 (score: 1)

I got help on GitHub. If you run into a similar problem, see the GitHub issue "The net using while_loop with batch_normalization can't train".
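The GitHub issue is the place to look for the official answer. As a hedged sketch of one workaround consistent with the error message (not necessarily the fix given there): tie the batch normalization update ops to the loop output inside the body, so they stay in the loop's frame, and stop wrapping minimize() in those in-loop ops. It assumes TensorFlow 1.x, or TF 2.x through tf.compat.v1 with the legacy tf.layers still available; build_graph and train are hypothetical names, random tensors stand in for the question's get_data(), and back_prop=False is dropped on the assumption that the layers inside the loop should be trained as well.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.15, or TF 2.x with the v1 compat layer

tf.disable_v2_behavior()


def build_graph(training=True):
    # Hypothetical stand-ins for get_data(): 8 random series of length 24, one-hot labels.
    rng = np.random.RandomState(0)
    data = tf.constant(rng.randn(8, 24, 1).astype(np.float32))
    labels = tf.one_hot([0, 1] * 4, 2)

    def cond(i, x):
        return i < 1

    def body(i, x):
        x = tf.layers.conv1d(x, 1, 7, padding='same')
        # Remember how many update ops existed before this layer was built...
        n_before = len(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
        x = tf.layers.batch_normalization(x, training=training)
        # ...so only the moving-average updates this layer just added are picked up.
        bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)[n_before:]
        # Attach them to the loop output INSIDE the loop frame.
        with tf.control_dependencies(bn_updates):
            x = tf.identity(x)
        x = tf.nn.relu(x)
        return i + 1, x

    _, y = tf.while_loop(cond, body, [0, data])
    logits = tf.layers.dense(tf.layers.flatten(y), 2)
    loss = tf.losses.mean_squared_error(labels, logits)
    # No tf.control_dependencies(UPDATE_OPS) here: those ops live in the
    # while-loop frame, which is what triggered the "different frames" error.
    train_op = tf.train.AdamOptimizer().minimize(loss)
    return loss, train_op


def train(steps=10):
    tf.reset_default_graph()
    loss, train_op = build_graph()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(steps):
            sess.run(train_op)
        moving_mean = [v for v in tf.global_variables() if 'moving_mean' in v.name][0]
        return sess.run([loss, moving_mean])
```

Snapshotting the collection length before building the layer keeps only that layer's updates. If the network also has batch normalization outside the loop, those update ops can still be wrapped around minimize() as in the question, as long as the in-loop ones are left out.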