我试图在python中开始使用TensorFlow,使用批量规范化构建一个简单的CNN。但是当我创建一个新图表来运行时,BN就会发生异常。
我的密码如下
**# exception here**
def batch_norm(x, beta, gamma, phase_train, scope='bn', decay=0.9, eps=1e-5):
with tf.variable_scope(scope):
batch_mean, batch_var = tf.nn.moments(x, [0], name='moments')
ema = tf.train.ExponentialMovingAverage(decay=decay)
def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean, batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
mean, var = tf.cond(phase_train, mean_var_with_update, lambda: (ema.average(batch_mean), ema.average(batch_var)))
normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
return normed
培训代码:
# start training
output = conv2d_net()
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.002).minimize(loss)
predict = tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN])
max_idx_p = tf.argmax(predict, 2)
max_idx_l = tf.argmax(tf.reshape(Y, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2)
correct_pred = tf.equal(max_idx_p, max_idx_l)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 0
while True:
batch_x, batch_y = get_next_batch(64)
_, loss_ = sess.run([optimizer, loss],
feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.75, train_phase: True})
print(step, loss_)
if step % 10 == 0 and step != 0:
batch_x_test, batch_y_test = get_next_batch(100)
acc = sess.run(accuracy,
feed_dict={X: batch_x_test, Y: batch_y_test, keep_prob: 1., train_phase: False})
print("step %s,accuracy:%s" % (step, acc))
if acc > 0.05:
# stop training and save parameters in layer
result_weights['wc1'] = weights['wc1'].eval(sess)
...
break
step += 1
为导出创建新图表:
EXPORT_DIR = './model'
if os.path.exists(EXPORT_DIR):
shutil.rmtree(EXPORT_DIR)
g = tf.Graph()
with g.as_default():
x_2 = tf.placeholder(tf.float32, shape=[None, IMAGE_HEIGHT * IMAGE_WIDTH], name="input")
x_image = tf.reshape(x_2, shape=[-1, IMAGE_HEIGHT, IMAGE_WIDTH, 1])
# fill trained parameters and create new cnn layers
WC1 = tf.constant(result_weights['wc1'], name="WC1")
...
**# crash here!!!**
CONV1 = conv2d(WC1, BC1, x_image, tf.constant(0.0, shape=[32]),
tf.random_normal(shape=[32], mean=1.0, stddev=0.02), scope='BN_1')
OUTPUT = tf.add(tf.matmul(FULL1, W_OUT), B_OUT)
OUTPUT = tf.nn.sigmoid(OUTPUT, name="output")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
graph_def = g.as_graph_def()
tf.train.write_graph(graph_def, EXPORT_DIR, 'phone_model_graph.pb', as_text=True)
我最后创建了一个新图表。该异常意味着它在旧训练图中使用了不正确的参数。怎么解释呢?
非常感谢!
我在fuction conv2d中调用batch_norm。似乎没有张量传递给新图。
def conv2d(w, b, x, tf_constant, tf_random_normal, scope, keep_p=1., phase=tf.constant(False)):
out = tf.nn.bias_add(tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME'), b)
out = batch_norm(out, tf_constant, tf_random_normal, phase, scope=scope)
out = tf.nn.relu(out)
out = tf.nn.max_pool(out, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
out = tf.nn.dropout(out, keep_p)
return out
答案 0 :(得分:0)
我最后创建了一个新图表。
这是关键声明:在创建新图表时,无法使用旧图表中的任何张量。请参阅this question中的详细说明。根据堆栈跟踪,传递给batch_norm
的至少一个张量是在g.as_default()
之前定义的,这就是为什么张量流崩溃的原因。从您的代码段开始,我不清楚batch_norm
的确切调用方式,因此我无法说出哪一个。
您可以打印x.graph
和g
并检查这些值是否不同来检查此假设。为了避免这个问题,您既可以在一个图形中完成所有工作(这是推荐的方式),也可以在不同的python范围内定义两个图形,从而无法在两个图形中意外地重用相同的python变量。