(Tensorflow) Am I using batch normalization correctly?

Time: 2020-03-23 11:59:41

Tags: python tensorflow batch-normalization

For the training data used in the mini-batch computations, the phase is set to 'phase = 1'.

        cost_val, hy_val, _ = sess.run(
            [cost, hypothesis, optimizer],
            feed_dict={X: x_data[i*batch_size:(i+1)*batch_size], Y: y_data[i*batch_size:(i+1)*batch_size], phase: 1})
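For reference, 'phase = 1' is meant to put batch normalization in training mode, where each mini-batch is normalized with its own mean and variance while moving averages of those statistics are updated. The NumPy function below is a simplified sketch of that behaviour, not TensorFlow's implementation; the name batch_norm_train is hypothetical, and momentum=0.99 and eps=1e-3 are assumed to match the defaults of tf.layers.batch_normalization. It normalizes over axis 0 of a 2-D (batch, features) array; for conv2d outputs the layer instead takes statistics per channel over the batch, height and width axes.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, moving_mean, moving_var,
                     momentum=0.99, eps=1e-3):
    """Hypothetical training-mode batch norm over axis 0 (the mini-batch axis).

    Normalizes with the statistics of the current mini-batch and returns
    updated moving averages. Simplified sketch, not TensorFlow code.
    """
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)
    y = gamma * (x - batch_mean) / np.sqrt(batch_var + eps) + beta
    # Exponential moving averages, which inference mode uses later.
    new_mean = momentum * moving_mean + (1.0 - momentum) * batch_mean
    new_var = momentum * moving_var + (1.0 - momentum) * batch_var
    return y, new_mean, new_var

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(64, 3))
# Moving statistics start at mean 0 and variance 1.
y, m, v = batch_norm_train(x, np.ones(3), np.zeros(3), np.zeros(3), np.ones(3))
```

Each feature of y has roughly zero mean and unit variance regardless of the input's scale, and the moving averages only move 1% of the way toward the batch statistics per step.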

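When printing the training error and the validation error after the mini-batch computations are finished, the phase for both the training data and the validation data is set to 'phase = 0'.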

if epoch % 10 == 0:
    rms_per_train, mae_train = sess.run([rms_per, mae], feed_dict={X: x_data, Y: y_data, phase: 0})

    cost_test, rms_test, diff_max_test, diff_min_test, rms_per_test, per_max_test, per_min_test, mae_test = sess.run(
        [cost, rms, diff_max, diff_min, rms_per, per_max, per_min, mae], feed_dict={X: x_test, Y: y_test, phase: 0})
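With 'phase = 0' the layer should switch to inference mode and normalize with the stored moving averages instead of the statistics of whatever is in the feed. A minimal NumPy sketch of that mode follows; batch_norm_infer and the moving statistics are hypothetical values for illustration, not TensorFlow code. It shows that each sample's output is independent of the rest of the batch, which is why evaluating all of x_data or x_test in one sess.run call gives the same per-sample results as evaluating them piece by piece.

```python
import numpy as np

def batch_norm_infer(x, gamma, beta, moving_mean, moving_var, eps=1e-3):
    """Hypothetical inference-mode batch norm: uses stored moving statistics only."""
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta

# Hypothetical moving statistics accumulated during training.
moving_mean = np.array([5.0, -1.0])
moving_var = np.array([4.0, 0.25])
gamma, beta = np.ones(2), np.zeros(2)

rng = np.random.default_rng(1)
x_eval = rng.normal(size=(10, 2))
full = batch_norm_infer(x_eval, gamma, beta, moving_mean, moving_var)
# The first row gives the same output whether or not other rows are present.
first_row_alone = batch_norm_infer(x_eval[:1], gamma, beta, moving_mean, moving_var)
```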

Is this the correct usage?

######model_example#######
conv1 = tf.layers.conv2d(inputs = x_input, filters = num_filter, kernel_size=[3,3], padding="SAME", strides=1, name = name_conv1,
                         kernel_initializer=tf.contrib.layers.xavier_initializer(), kernel_regularizer=regularizer)
bn1 = tf.layers.batch_normalization(inputs = conv1, name = name_bn1)
relu1 = tf.nn.relu(bn1, name = name_relu1)

# ...
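In the model fragment above, tf.layers.batch_normalization is called without a training argument. In the TF 1.x API that argument defaults to False, so, as written, the layer would always normalize with its moving statistics and the phase value fed in feed_dict would not affect it; the documented way to switch modes is to pass training=phase, with phase typically created as tf.placeholder(tf.bool, name='phase'). The NumPy class below is a hypothetical stand-in (not TensorFlow code) showing the consequence: with training left at its default and the moving statistics still at their initial mean 0 and variance 1, the "normalized" output is essentially the unnormalized input.

```python
import numpy as np

class BatchNormSketch:
    """Hypothetical NumPy stand-in for a batch-norm layer (not TensorFlow code).

    As assumed for tf.layers.batch_normalization, `training` defaults to False
    and the moving statistics start at mean 0 and variance 1.
    """

    def __init__(self, num_features, momentum=0.99, eps=1e-3):
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        self.moving_mean = np.zeros(num_features)
        self.moving_var = np.ones(num_features)
        self.momentum = momentum
        self.eps = eps

    def __call__(self, x, training=False):
        if training:
            # Batch statistics, plus an update of the moving averages.
            mean, var = x.mean(axis=0), x.var(axis=0)
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Moving statistics only; nothing is updated.
            mean, var = self.moving_mean, self.moving_var
        return self.gamma * (x - mean) / np.sqrt(var + self.eps) + self.beta

rng = np.random.default_rng(2)
x = rng.normal(loc=10.0, scale=3.0, size=(32, 4))

bn = BatchNormSketch(4)
out_default = bn(x)               # like omitting training=...
out_train = bn(x, training=True)  # like training=phase with phase fed as True
```

Here out_default still has a mean near 10, while out_train is centred at 0.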

# define cost/loss & optimizer
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)

# ...
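The tf.control_dependencies(update_ops) block is what makes each optimizer step also run the moving-average updates that batch normalization registers in tf.GraphKeys.UPDATE_OPS (collected here after the model has been built, so the layers' update ops are already in the collection). If those updates never run, the statistics used when phase = 0 stay at their initial values. The simulation below is a NumPy sketch of that effect, not TensorFlow code; simulate_training and its run_updates switch are hypothetical stand-ins for whether the update ops execute alongside each training step.

```python
import numpy as np

def simulate_training(batches, run_updates, momentum=0.99):
    """Return moving (mean, var) after a sequence of training steps.

    `run_updates` is a hypothetical switch standing in for whether the
    UPDATE_OPS are executed alongside each optimizer step.
    """
    moving_mean, moving_var = 0.0, 1.0  # initial moving statistics
    for batch in batches:
        if run_updates:
            moving_mean = momentum * moving_mean + (1 - momentum) * batch.mean()
            moving_var = momentum * moving_var + (1 - momentum) * batch.var()
    return moving_mean, moving_var

rng = np.random.default_rng(3)
batches = [rng.normal(loc=4.0, scale=2.0, size=128) for _ in range(2000)]

with_updates = simulate_training(batches, run_updates=True)
without_updates = simulate_training(batches, run_updates=False)
```

With updates, the moving statistics approach the data's mean of 4 and variance of 4; without them they remain (0.0, 1.0), which would make inference-mode outputs inconsistent with what the network saw during training.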

for epoch in range(30001):
    step = epoch
    p = np.random.permutation(len(x_data))
    x_data=x_data[p]
    y_data=y_data[p]
    total_batch = 49000
    avg_cost_test = avg_cost_train = 0

    for i in range(total_batch):
        cost_val, hy_val, _ = sess.run(
            [cost, hypothesis, optimizer],
            feed_dict={X: x_data[i*batch_size:(i+1)*batch_size], Y: y_data[i*batch_size:(i+1)*batch_size], phase: 1})

        avg_cost_train += cost_val / total_batch

    if epoch % 10 == 0:
        rms_per_train, mae_train = sess.run([rms_per, mae], feed_dict={X: x_data, Y: y_data, phase: 0})

        cost_test, rms_test, diff_max_test, diff_min_test, rms_per_test, per_max_test, per_min_test, mae_test = sess.run(
            [cost, rms, diff_max, diff_min, rms_per, per_max, per_min, mae], feed_dict={X: x_test, Y: y_test, phase: 0})

        out = open('M2NR_resnet_34L_64W_0.001_reg_0.0001.out','a')
        print(epoch,'epoch' , ' cost_train', avg_cost_train, "cost_test ", cost_test, per_max_test, per_min_test, rms_per_test, rms_per_train, mae_test, mae_train)
        out.write('%d Cost_train %e  Cost_test %e %e %e %e %e %e %e\n' %(epoch,avg_cost_train,cost_test,per_max_test,per_min_test,rms_per_test, rms_per_train, mae_test, mae_train))
        out.close()
        saver.save(sess, 'model/M2NR_resnet_34L_64W_0.001_reg_0.0001/M2NR_resnet_34L_64W_0.001_reg_0.0001', global_step=epoch, write_meta_graph=False)
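In the loop above, the number of mini-batches is hard-coded (total_batch = 49000) while the slices are computed from batch_size. A common alternative derives the batch count from the data size so that every shuffled sample is visited exactly once per epoch. A minimal sketch of that approach follows; iterate_minibatches is a hypothetical helper and the toy arrays are illustrative only.

```python
import numpy as np

def iterate_minibatches(x, y, batch_size, rng):
    """Yield shuffled (x, y) mini-batches covering every sample once per epoch.

    The number of batches is derived from len(x) instead of being hard-coded;
    the last batch may be smaller than batch_size.
    """
    assert len(x) == len(y)
    p = rng.permutation(len(x))  # one permutation keeps x/y pairs aligned
    x, y = x[p], y[p]
    num_batches = (len(x) + batch_size - 1) // batch_size  # ceiling division
    for i in range(num_batches):
        sl = slice(i * batch_size, (i + 1) * batch_size)
        yield x[sl], y[sl]

rng = np.random.default_rng(4)
x_data = np.arange(10).reshape(10, 1).astype(float)
y_data = 2 * x_data
batches = list(iterate_minibatches(x_data, y_data, batch_size=4, rng=rng))
```

Inside the epoch loop, each (xb, yb) pair would be fed together with phase set to 1 as in the original code, and avg_cost_train would be divided by the actual number of batches. Because a very small final batch gives noisy batch-normalization statistics in training mode, some code drops the remainder instead by using floor division.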
