Tensorflow:卷积中的无效参数错误

时间:2017-04-06 15:03:46

标签: tensorflow deep-learning convolution

我试图运行这段Python代码并且似乎无法绕过错误:

tf.nn.conv2d(tf.reshape(x, [5, 5]), tf.reshape(wt, [3, 3]), strides=[1, 1],  padding='SAME')

这里,x是来自(5,5)numpy数组的tf.Variable,w是来自(3,3)numpy数组的变量。

我得到的错误是:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
C:\Anaconda3\lib\site-packages\tensorflow\python\framework\common_shapes.py in _call_cpp_shape_fn_impl(op, input_tensors_needed, input_tensors_as_shapes_needed, debug_python_shape_fn, require_shape_fn)
    669           node_def_str, input_shapes, input_tensors, input_tensors_as_shapes,
--> 670           status)
    671   except errors.InvalidArgumentError as err:

C:\Anaconda3\lib\contextlib.py in __exit__(self, type, value, traceback)
     65             try:
---> 66                 next(self.gen)
     67             except StopIteration:

C:\Anaconda3\lib\site-packages\tensorflow\python\framework\errors_impl.py in raise_exception_on_not_ok_status()
    468           compat.as_text(pywrap_tensorflow.TF_Message(status)),
--> 469           pywrap_tensorflow.TF_GetCode(status))
    470   finally:

InvalidArgumentError: Shape must be rank 4 but is rank 2 for 'Conv2D_19' (op: 'Conv2D') with input shapes: [5,5], [3,3].

1 个答案:

答案 0 :(得分:0)

使用tf.nn.conv2d。您的输入和过滤器都应转换为4D。此外,strides应为1-D of length 4(输入的每个维度的滑动窗口)。以下摘自documentation

  

给定输入张量的形状[batch,in_height,in_width,   in_channels]和形状的过滤器/内核张量[filter_height,   filter_width,in_channels,out_channels],此操作执行   以下内容:

     

将滤镜展平为具有形状的二维矩阵[filter_height *   filter_width * in_channels,output_channels]。提取图像补丁   从输入张量形成虚拟张量的形状[批量,   out_height,out_width,filter_height * filter_width * in_channels]。   对于每个补丁,右对乘滤波器矩阵和图像补丁   矢量。

您可以:tf.reshape(x, [1, 5, 5, 1])代表数据,tf.reshape(wt, [3, 3, 1, 1])代表过滤器,strides=[1, 1, 1, 1]。这导致:

tf.nn.conv2d(tf.reshape(x, [1, 5, 5, 1]), tf.reshape(wt, [3, 3, 1, 1]), strides=[1, 1, 1, 1],  padding='SAME')