Reproducing the Spatial Transformer Network in TensorFlow

Time: 2018-06-02 13:32:29

Tags: tensorflow deep-learning

I cloned spatial-transformer-network to reproduce the STN with TensorFlow 1.8.0, and downloaded the mnist_cluttered_60x60_6distortions.npz dataset from here. When I run main.py, the following error occurs:

Loading the data...
Building ConvNet...
Traceback (most recent call last):
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 517, in make_tensor_proto
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 517, in <listcomp>
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/compat.py", line 67, in as_bytes
    (bytes_or_text,))
TypeError: Expected binary or unicode string, got -1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "main.py", line 422, in <module>
    main()
  File "main.py", line 275, in main
    pool3_flat, pool3_size = Flatten(pool3)
  File "/home/LiChenyang/spatial-transformer-network/utils/layer_utils.py", line 85, in Flatten
    layer_flat = tf.reshape(layer, [-1, num_features])
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 6113, in reshape
    "Reshape", tensor=tensor, shape=shape, name=name)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 513, in _apply_op_helper
    raise err
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 510, in _apply_op_helper
    preferred_dtype=default_dtype)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1104, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 235, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 214, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/home/xinlab/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 521, in make_tensor_proto
    "supported type." % (type(values), values))
TypeError: Failed to convert object of type <class 'list'> to Tensor. Contents: [-1, None]. Consider casting elements to a supported type.
LiChenyang@ubuntu:~/spatial-transformer-network$

My guess is that the error occurs because the last dimension of the target shape is None (i.e. [-1, None]) when this line executes: layer_flat = tf.reshape(layer, [-1, num_features])

I tried printing h_trans and pool3 in main() and got:

Tensor("AddN:0", shape=(?, ?, ?, 1), dtype=float32)
Tensor("Relu_5:0", shape=(?, ?, ?, 128), dtype=float32)

The second and third dimensions of h_trans are '?' (unknown), and this propagates to pool3, so num_features in Flatten() ends up as None.
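The failure mode can be reproduced without TensorFlow. Below is a minimal pure-Python sketch, assuming TensorFlow's static-shape convention in which an unknown dimension is None and `TensorShape.num_elements()` returns None whenever any dimension is unknown. The helper `num_elements` is hypothetical and only mimics that behavior; it is not TensorFlow code.

```python
from functools import reduce
from operator import mul
from typing import Optional, Sequence


def num_elements(dims: Sequence[Optional[int]]) -> Optional[int]:
    """Hypothetical stand-in for TensorShape.num_elements():
    product of the dims, or None if any dim is unknown."""
    if any(d is None for d in dims):
        return None
    return reduce(mul, dims, 1)


# Static shape printed for pool3: (?, ?, ?, 128)
pool3_static = (None, None, None, 128)
num_features = num_elements(pool3_static[1:])
reshape_target = [-1, num_features]
print(reshape_target)  # [-1, None]

# With fully known spatial dims the same computation gives a plain int.
print(num_elements((8, 8, 128)))  # 8192
```

The `[-1, None]` list is exactly what `tf.reshape` tries to convert into a constant tensor in the traceback, and None is not a convertible element, which is why the TypeError is raised.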

To test this idea, I added a new function to ./utils/layer_utils.py:

def Flatten_test(layer):
    """
    Handy function for flattening the result of a conv2D or
    maxpool2D to be used for a fully-connected (affine) layer.
    """
    # layer_shape = layer.get_shape()
    # # num_features = tf.reduce_prod(tf.shape(layer)[1:])
    # num_features = layer_shape[1:].num_elements()

    num_features = 8 * 8 * 128  # hard-coded: pool3 assumed to be 8x8 with 128 channels
    layer_flat = tf.reshape(layer, [-1, num_features])

    return layer_flat, num_features

It computes num_features by hand. I then changed:

pool3_flat, pool3_size = Flatten(pool3)

to

pool3_flat, pool3_size = Flatten_test(pool3)
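The workaround hard-codes the flattened size, which silently breaks if the architecture changes. A slightly safer variant, sketched below in plain Python, computes the size from an explicitly stated expected shape but first checks it against whatever dimensions TensorFlow does know. The function `resolve_num_features` and its `expected` parameter are hypothetical, and the 8x8x128 value is the one assumed in `Flatten_test`, not something derived from the repository.

```python
from typing import Optional, Sequence


def resolve_num_features(static_dims: Sequence[Optional[int]],
                         expected: Sequence[int]) -> int:
    """Hypothetical helper: flattened size of the non-batch dims.

    static_dims: non-batch dims as TensorFlow reports them (None = unknown).
    expected:    the shape the caller believes the layer has, e.g. (8, 8, 128).
    """
    if len(static_dims) != len(expected):
        raise ValueError("rank mismatch: %r vs %r" % (static_dims, expected))
    total = 1
    for known, exp in zip(static_dims, expected):
        # Any dimension TensorFlow does know must agree with the assumption.
        if known is not None and known != exp:
            raise ValueError("dim mismatch: %r vs %r" % (static_dims, expected))
        total *= exp
    return total


# pool3's non-batch static dims are (?, ?, 128); the channel count is checked.
print(resolve_num_features((None, None, 128), (8, 8, 128)))  # 8192
```

Inside the TensorFlow 1.x graph, a comparable option would be restoring the static shape with `Tensor.set_shape` before calling `Flatten`; whether that is valid depends on the transformer's actual output size, which this sketch does not verify.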

With this change, main.py runs successfully.

However, I'm still confused: why can other people run this code as-is? Is it caused by differences between TensorFlow versions?

0 answers