I cloned spatial-transformer-network to reproduce STN with TensorFlow 1.8.0, and downloaded the mnist_cluttered_60x60_6distortions.npz dataset from here.
When I run main.py, the following error occurs:
Loading the data...
Building ConvNet...
Traceback (most recent call last):
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 517, in make_tensor_proto
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 517, in <listcomp>
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/compat.py", line 67, in as_bytes
    (bytes_or_text,))
TypeError: Expected binary or unicode string, got -1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 422, in <module>
    main()
  File "main.py", line 275, in main
    pool3_flat, pool3_size = Flatten(pool3)
  File "/home/LiChenyang/spatial-transformer-network/utils/layer_utils.py", line 85, in Flatten
    layer_flat = tf.reshape(layer, [-1, num_features])
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 6113, in reshape
    "Reshape", tensor=tensor, shape=shape, name=name)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 513, in _apply_op_helper
    raise err
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 510, in _apply_op_helper
    preferred_dtype=default_dtype)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1104, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 235, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 214, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 521, in make_tensor_proto
    "supported type." % (type(values), values))
TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [-1, None]. Consider casting elements to a supported type.
LiChenyang@ubuntu:~/spatial-transformer-network$
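My guess is that the error occurs because the last dimension of the target shape is None (the traceback reports the contents as [-1, None]) when this line executes: layer_flat = tf.reshape(layer, [-1, num_features]).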
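To make the guess concrete, here is a pure-Python model of how the shape list passed to tf.reshape could end up as [-1, None]. It is not TensorFlow code: the helper names are hypothetical, and it assumes Flatten computes num_features as in the commented-out line of Flatten_test (layer_shape[1:].num_elements()). The rule it models is that TensorShape.num_elements() returns None when the shape is not fully known.

```python
from functools import reduce
import operator
from typing import List, Optional

Dim = Optional[int]  # None stands for an unknown ('?') dimension


def static_num_elements(dims: List[Dim]) -> Optional[int]:
    """Product of the dims, or None if any dim is unknown.

    Models TensorShape.num_elements() for an incompletely known shape.
    """
    if any(d is None for d in dims):
        return None
    return reduce(operator.mul, dims, 1)


def flatten_target_shape(layer_dims: List[Dim]) -> List[Dim]:
    """Build the [-1, num_features] list that Flatten() hands to tf.reshape."""
    num_features = static_num_elements(layer_dims[1:])  # skip the batch axis
    return [-1, num_features]


# Spatial dims unknown, as in the printed shape (?, ?, ?, 128): the list contains None
print(flatten_target_shape([None, None, None, 128]))  # [-1, None]
# The same layer with known spatial dims gives a valid integer shape
print(flatten_target_shape([None, 8, 8, 128]))  # [-1, 8192]
```

A list containing None cannot be packed into an integer tensor, which is consistent with the final "Failed to convert object ... Contents: [-1, None]" error.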
I tried printing h_trans and pool3 in main(), and got:
Tensor("AddN:0", shape=(?, ?, ?, 1), dtype=float32)
Tensor("Relu_5:0", shape=(?, ?, ?, 128), dtype=float32)
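The second and third dimensions of h_trans are '?' (unknown), and pool3 inherits the same unknown spatial dimensions, which makes num_features in Flatten() None.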
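Why an unknown size in h_trans stays unknown in pool3 can be sketched with the SAME-padding rule, where a pooling (or convolution) layer outputs ceil(size / stride) along each spatial axis. The pure-Python sketch below models that rule. The architecture it uses (three stride-2 SAME pools on the 60×60 input, with stride-1 SAME convolutions that leave the size unchanged) is an assumption, not something confirmed from the repository, but it would account for the 8 × 8 spatial size used in Flatten_test. The helper names are hypothetical.

```python
import math
from typing import Optional


def same_pool_out(size: Optional[int], stride: int = 2) -> Optional[int]:
    """Output length along one axis for padding='SAME': ceil(size / stride).

    An unknown input length (None) can only produce an unknown output length.
    """
    if size is None:
        return None
    return math.ceil(size / stride)


def after_pools(size: Optional[int], n_pools: int = 3) -> Optional[int]:
    """Apply n_pools stride-2 SAME pools in a row (assumed architecture)."""
    for _ in range(n_pools):
        size = same_pool_out(size)
    return size


print(after_pools(60))    # 8  (60 -> 30 -> 15 -> 8)
print(after_pools(None))  # None: a '?' in h_trans stays '?' in pool3
```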
To verify this idea, I added a new function to ./utils/layer_utils.py:
def Flatten_test(layer):
    """
    Handy function for flattening the result of a conv2D or
    maxpool2D to be used for a fully-connected (affine) layer.
    """
    # layer_shape = layer.get_shape()
    # # num_features = tf.reduce_prod(tf.shape(layer)[1:])
    # num_features = layer_shape[1:].num_elements()
    # Hard-coded for this test: pool3 is assumed to be 8 x 8 x 128 per sample
    num_features = 8 * 8 * 128
    layer_flat = tf.reshape(layer, [-1, num_features])
    return layer_flat, num_features
It hard-codes num_features (8 × 8 × 128 = 8192) instead of reading it from the layer's static shape. I then changed:
pool3_flat, pool3_size = Flatten(pool3)
to
pool3_flat, pool3_size = Flatten_test(pool3)
With that change, main.py finally runs successfully.
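One caveat with the hard-coded workaround: because the batch axis is -1, the reshape infers it from the total element count, so a wrong num_features does not always fail loudly. The sketch below uses NumPy as a stand-in for tf.reshape (both infer a -1 axis the same way); the 4×4 shape and the flatten_checked helper are hypothetical, chosen only to illustrate the failure mode and a per-sample size check.

```python
import numpy as np

NUM_FEATURES = 8 * 8 * 128  # the value hard-coded in Flatten_test

# Correct spatial size: 4 samples of 8x8x128 flatten to 4 rows
good = np.zeros((4, 8, 8, 128), dtype=np.float32)
print(good.reshape(-1, NUM_FEATURES).shape)  # (4, 8192)

# Hypothetical wrong spatial size (4x4): 8 * 4*4*128 = 16384 = 2 * 8192,
# so the reshape still succeeds, but 8 samples silently become 2 rows
bad = np.zeros((8, 4, 4, 128), dtype=np.float32)
print(bad.reshape(-1, NUM_FEATURES).shape)  # (2, 8192)


def flatten_checked(x: np.ndarray, num_features: int) -> np.ndarray:
    """Flatten to (batch, num_features), refusing a per-sample size mismatch."""
    per_sample = int(np.prod(x.shape[1:]))
    if per_sample != num_features:
        raise ValueError(
            f"expected {num_features} features per sample, got {per_sample}")
    return x.reshape(x.shape[0], num_features)
```

In graph-mode TensorFlow an equivalent runtime check would have to use the dynamic shape, tf.shape(layer), as in the commented-out line of Flatten_test, since the static shape is exactly what is unknown here.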
However, I am still confused: why can other people run this code as-is? Is it caused by differences between TensorFlow versions?