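I'm implementing OCR with an RNN on top of CNN features, and I can't understand how tf.space_to_batch_nd() and tf.batch_to_space_nd() work. I need to apply a fully connected layer to each time slice (dimension 1) of the input layer in order to reduce the tensor rank from 4 to 3. So far I have done this with tf.reshape(). The input tensor has shape [1, 84, 7, 128] and the output should have shape [1, 84, 128]. My current implementation is: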
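with tf.variable_scope('dim_redux') as scope:
    conv_out_shape = tf.shape(net)
    print("Conv out:", str(net))
    # One weight matrix shared by all time steps: 7 * 128 inputs -> 128 outputs.
    w_fc1 = weight_variable([7 * 128, 128])
    b_fc1 = bias_variable([128])
    # Merge batch and time so that each row is one flattened time slice.
    conv_layer_flat = tf.reshape(net, [-1, 7 * 128])
    features = tf.matmul(conv_layer_flat, w_fc1) + b_fc1
    features = lrelu(features)
    # Restore the [batch, time, features] layout.
    features = tf.reshape(features, [batch_size, int(84), CONV_FC_OUTPUT])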
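To make the intent of the reshape version concrete, here is a minimal NumPy sketch (not the TensorFlow code itself) that checks that merging batch and time before a single matmul gives the same result as applying one shared dense layer to every time slice. The random weights, the zero bias and all variable names are assumptions made for illustration only, and the activation is left out.

```python
import numpy as np

rng = np.random.RandomState(0)
batch, steps, height, channels, out_units = 1, 84, 7, 128, 128

# Stand-in for the CNN output of shape [batch, time, height, channels].
conv_out = rng.randn(batch, steps, height, channels)
# Hypothetical dense-layer parameters shared by all time steps.
w = rng.randn(height * channels, out_units) * 0.01
b = np.zeros(out_units)

# Reshape route: fold batch and time into one axis, apply the layer once.
flat = conv_out.reshape(-1, height * channels)            # (84, 896)
features = (flat @ w + b).reshape(batch, steps, out_units)

# Reference route: apply the same layer to each time slice separately.
looped = np.stack(
    [conv_out[:, t].reshape(batch, -1) @ w + b for t in range(steps)],
    axis=1,
)

print(features.shape)                 # shape of the rank-3 result
print(np.allclose(features, looped))  # whether both routes agree
```

Both routes use the same weights for every time step, which is what "a fully connected layer on each time slice" means here.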
And my attempt using tf.space_to_batch_nd() and tf.batch_to_space_nd():
with tf.variable_scope('dim_redux') as scope:
    net = tf.space_to_batch_nd(net, block_shape=[84, 1], paddings=[[0, 0], [0, 0]])
    print(net)
    net = tf.contrib.layers.flatten(net)
    net = tf.contrib.layers.fully_connected(net, CONV_FC_OUTPUT, biases_initializer=tf.zeros_initializer())
    net = lrelu(net)
    print(net)
    net = tf.batch_to_space_nd(net, block_shape=[84, 128], crops=[[0, 0], [0, 0]])
    features = net
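It seems to me that the block shape should be [1, 7], but only the value [84, 1] makes tf.space_to_batch_nd() return a tensor with the correct shape [84, 1, 7, 128]. With the current parameters I get this error: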
File "/Users/akislinskiy/tag_price_ocr/ocr.py", line 336, in convolutional_layers
net = tf.batch_to_space_nd(net, block_shape=[84, 128], crops=[[0, 0], [0, 0]])
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 318, in batch_to_space_nd
name=name)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
op_def=op_def)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2632, in create_op
set_shapes_for_outputs(ret)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1911, in set_shapes_for_outputs
shapes = shape_func(op)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1861, in call_with_requiring
return call_cpp_shape_fn(op, require_shape_fn=True)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 595, in call_cpp_shape_fn
require_shape_fn)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 659, in _call_cpp_shape_fn_impl
raise ValueError(err.message)
ValueError: Shape must be at least rank 3 but is rank 2 for 'convolutions/dim_redux/BatchToSpaceND' (op: 'BatchToSpaceND') with input shapes: [84,128], [2], [2,2].
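The error message fits the rule that a block_shape with M entries needs an input of rank at least 1 + M: here block_shape=[84, 128] has M = 2, while the fully connected output [84, 128] has rank 2. The sketch below emulates both ops in NumPy for the zero-padding, zero-crop case only, to show where each axis goes. The function names space_to_batch_nd_np and batch_to_space_nd_np are hypothetical helpers, not TensorFlow APIs. The idea of adding a length-1 spatial axis and using block_shape=[84] is an assumption that has not been checked against the TensorFlow version in the traceback.

```python
import numpy as np

def space_to_batch_nd_np(x, block_shape):
    """NumPy emulation of space-to-batch with zero paddings (illustration only)."""
    m = len(block_shape)
    batch, spatial, rest = x.shape[0], x.shape[1:1 + m], x.shape[1 + m:]
    split = [batch]
    for s, blk in zip(spatial, block_shape):
        assert s % blk == 0, "spatial size must be divisible by the block size"
        split += [s // blk, blk]          # each spatial dim -> (outer, block)
    y = x.reshape(split + list(rest))
    block_axes = [2 + 2 * i for i in range(m)]
    outer_axes = [1 + 2 * i for i in range(m)]
    rest_axes = list(range(1 + 2 * m, y.ndim))
    # Move the block axes in front of the batch axis, then merge them into it.
    y = y.transpose(block_axes + [0] + outer_axes + rest_axes)
    out_spatial = [s // blk for s, blk in zip(spatial, block_shape)]
    return y.reshape([batch * int(np.prod(block_shape))] + out_spatial + list(rest))

def batch_to_space_nd_np(x, block_shape):
    """NumPy emulation of batch-to-space with zero crops (illustration only)."""
    m = len(block_shape)
    if x.ndim < 1 + m:
        raise ValueError("input rank must be at least 1 + len(block_shape)")
    batch = x.shape[0] // int(np.prod(block_shape))
    spatial, rest = x.shape[1:1 + m], x.shape[1 + m:]
    y = x.reshape(list(block_shape) + [batch] + list(spatial) + list(rest))
    perm = [m]                            # batch axis first
    for i in range(m):
        perm += [m + 1 + i, i]            # interleave (spatial_i, block_i)
    perm += list(range(2 * m + 1, y.ndim))
    y = y.transpose(perm)
    out_spatial = [s * blk for s, blk in zip(spatial, block_shape)]
    return y.reshape([batch] + out_spatial + list(rest))

x = np.arange(1 * 84 * 7 * 128, dtype=np.float64).reshape(1, 84, 7, 128)
slices = space_to_batch_nd_np(x, [84, 1])         # time axis moved into batch
flat = slices.reshape(84, -1)                     # one row per time slice
w = np.random.RandomState(0).randn(7 * 128, 128)  # hypothetical dense weights
feats = flat @ w                                  # rank 2, like the FC output

# Give the rank-2 result a length-1 spatial axis, then undo the batch split
# with a one-dimensional block.
restored = batch_to_space_nd_np(feats[:, None, :], [84])
print(slices.shape, feats.shape, restored.shape)
```

In TensorFlow terms this would correspond to something like tf.batch_to_space_nd(tf.expand_dims(net, 1), block_shape=[84], crops=[[0, 0]]), again as an untested assumption. Since the block covers the whole time axis, the round trip only moves time into the batch axis and back, which is the same thing the tf.reshape() version achieves.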