TensorFlow unpooling after max_pool_with_argmax using indices

Asked: 2018-09-22 13:41:35

Tags: tensorflow machine-learning deep-learning

While trying to implement U-SegNet from Google's paper, I ran into a problem implementing the unpooling operation with argmax indices.

Full code:

import tensorflow as tf


def unpool(pool, ind, ksize=[1, 2, 2, 1], name=None):
    with tf.variable_scope('name') as scope:
        input_shape = tf.shape(pool)
        output_shape = [input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]]

        flat_input_size = tf.cumprod(input_shape)[-1]
        flat_output_shape = tf.stack([output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]])

        pool_ = tf.reshape(pool, tf.stack([flat_input_size]))
        batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
                                        shape=tf.stack([input_shape[0], 1, 1, 1]))
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, tf.stack([flat_input_size, 1]))
        ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
        ind_ = tf.concat([b, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
        ret = tf.reshape(ret, tf.stack(output_shape))

        set_input_shape = pool.get_shape()
        set_output_shape = [set_input_shape[0], set_input_shape[1] * ksize[1], set_input_shape[2] * ksize[2], set_input_shape[3]]
        ret.set_shape(set_output_shape)
    return ret

with tf.Session() as sess:
    x = tf.random_normal([1, 4, 4, 1])
    y, ind = tf.nn.max_pool_with_argmax(
        x,
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding='SAME'
    )

    z = unpool(y, ind)

    x_, y_, z_ = sess.run([x, y, z])

It works fine for batch size 1, but for batch size > 1 it crashes with the following error:

2018-09-22 16:33:57.010504: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2018-09-22 16:33:57.082638: W tensorflow/core/framework/op_kernel.cc:1275] OP_REQUIRES failed at scatter_nd_op.cc:119 : Invalid argument: Invalid indices: [2,0] = [1, 21] does not index into [4,16]
Traceback (most recent call last):
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1278, in _do_call
    return fn(*args)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1263, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1350, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Invalid indices: [2,0] = [1, 21] does not index into [4,16]
     [[Node: name/ScatterNd = ScatterNd[T=DT_FLOAT, Tindices=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](name/concat, name/Reshape, name/Cast_1)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "tst.py", line 39, in <module>
    x_, y_, z_ = sess.run([x, y, z])
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 877, in run
    run_metadata_ptr)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1100, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1272, in _do_run
    run_metadata)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1291, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Invalid indices: [2,0] = [1, 21] does not index into [4,16]
     [[Node: name/ScatterNd = ScatterNd[T=DT_FLOAT, Tindices=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](name/concat, name/Reshape, name/Cast_1)]]

Caused by op 'name/ScatterNd', defined at:
  File "tst.py", line 37, in <module>
    z = unpool(y, ind)
  File "tst.py", line 20, in unpool
    ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 6788, in scatter_nd
    "ScatterNd", indices=indices, updates=updates, shape=shape, name=name)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 454, in new_func
    return func(*args, **kwargs)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3155, in create_op
    op_def=op_def)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1717, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Invalid indices: [2,0] = [1, 21] does not index into [4,16]
     [[Node: name/ScatterNd = ScatterNd[T=DT_FLOAT, Tindices=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](name/concat, name/Reshape, name/Cast_1)]]

What could be wrong, and how can I fix it?

The unpooling function is taken from this issue on GitHub, but there is no information there about unpooling with batches.

My tf.__version__ is 1.10.
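
A note on reading the error message: `[2,0] = [1, 21] does not index into [4,16]` says that for batch element 1 the scatter received position 21, while each sample only has 16 output slots. The sketch below decodes that value, assuming the flattened layout `((b * height + y) * width + x) * channels + c` that the `tf.nn.max_pool_with_argmax` documentation describes. The helper name `decode_flat_index` is hypothetical, and the shapes are read off the error message rather than taken from the posted script.

```python
def decode_flat_index(flat_index, height, width, channels):
    """Split a batch-inclusive flat index into (b, y, x, c).

    Assumes the layout ((b * height + y) * width + x) * channels + c.
    """
    rest, c = divmod(flat_index, channels)
    rest, x = divmod(rest, width)
    b, y = divmod(rest, height)
    return b, y, x, c


# Shapes inferred from "does not index into [4,16]": 4 samples, 4 * 4 * 1 slots each.
height, width, channels = 4, 4, 1
per_sample = height * width * channels

b, y, x, c = decode_flat_index(21, height, width, channels)
local_index = 21 % per_sample  # position inside the sample, batch offset removed

print(b, y, x, c, local_index)
```

Under that assumption, 21 already contains the offset for batch element 1 (1 * 16 + 5), so pairing it with a separate batch coordinate counts the batch twice. With batch size 1 the offset is always 0, which would explain why the single-sample case works.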

3 answers:

Answer 0: (score: 0)

@Tofik.AI Which TensorFlow version are you using? According to the latest documentation, this is not correct. My implementation:

from sympy import *
z_re = Symbol('z_re',Complex=True)
z_re_c = conjugate(z_re)

e1 = Mul(z_re,Pow(Add(z_re,Integer(-1)),Integer(-1)))
e2 = Mul(z_re,z_re_c,Pow(Add(z_re,Integer(-1)),Integer(-1)),Pow(Add(z_re_c,Integer(-1)),Integer(-1)))
e3 = Mul(z_re_c,Pow(Add(z_re_c,Integer(-1)),Integer(-1)))
e4 = Add(e1,e2,e3)

e5 =e4.factor()
e6 = fraction(e5)[0] # just the numerator

Answer 1: (score: 0)

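There is a repository that implements an unpool op in CUDA. The unpool_example.py file shows how to use the library. In initial tests it was about twice as fast at inference time as composing existing TensorFlow functions (and about four times as fast during training).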

Just use it as follows:

import unpool

#pool, inds = max_pool_with_inds
unpool_layer = unpool.unpool(pool, inds,
                             output_size=[height, width],
                             name="unpool")

Full disclosure: I wrote this repo.

Answer 2: (score: -1)

Your code runs fine:

import tensorflow as tf

def unpool(pool, ind, ksize=[1, 2, 2, 1], name=None):
    with tf.variable_scope('name') as scope:
        input_shape = tf.shape(pool)
        output_shape = [input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]]

        flat_input_size = tf.cumprod(input_shape)[-1]
        flat_output_shape = tf.stack([output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]])

        pool_ = tf.reshape(pool, tf.stack([flat_input_size]))
        batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
                                        shape=tf.stack([input_shape[0], 1, 1, 1]))
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, tf.stack([flat_input_size, 1]))
        ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
        ind_ = tf.concat([b, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
        ret = tf.reshape(ret, tf.stack(output_shape))

        set_input_shape = pool.get_shape()
        set_output_shape = [set_input_shape[0], set_input_shape[1] * ksize[1], set_input_shape[2] * ksize[2], set_input_shape[3]]
        ret.set_shape(set_output_shape)
    return ret


batch_size=10
with tf.Session() as sess:

    x = tf.random_normal([batch_size,16,16,1])
    y, ind = tf.nn.max_pool_with_argmax(
        x,
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding='SAME'
    )

    z = unpool(y, ind)
    x_, y_, z_=sess.run([x, y, z])



aa=x_[4,:,:,0]
bb=y_[4,:,:,0]
cc=z_[4,:,:,0]

You could update TensorFlow. I am using tensorflow 1.12.0.
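
The same code failing on 1.10 and running on 1.12.0 would be consistent with the argmax indices including the batch offset in one setup and not in the other. That is an assumption and has not been checked against those releases. Reducing each index modulo the per-sample output size makes the scatter step independent of that convention. Here is a plain-Python sketch of the idea; `unpool_reference` and the sample data are hypothetical and only illustrate the index handling, not a TensorFlow API.

```python
def unpool_reference(pooled, indices, out_height, out_width, channels):
    """Scatter pooled maxima back to their argmax positions.

    pooled, indices: one flat list of values / flat argmax indices per sample.
    Returns one flat list of length out_height * out_width * channels per sample.
    """
    per_sample = out_height * out_width * channels
    outputs = []
    for values, idxs in zip(pooled, indices):
        out = [0.0] * per_sample
        for value, idx in zip(values, idxs):
            # Modulo strips a batch offset if one is present, and is a no-op otherwise.
            out[idx % per_sample] = value
        outputs.append(out)
    return outputs


# Two samples, 2x2 output, 1 channel, one pooled value per sample (made-up data).
pooled = [[7.0], [9.0]]
with_batch = [[3], [4 + 2]]   # second sample: batch offset 4 plus local index 2
without_batch = [[3], [2]]    # same positions without the batch offset

a = unpool_reference(pooled, with_batch, 2, 2, 1)
b = unpool_reference(pooled, without_batch, 2, 2, 1)
print(a)
print(a == b)
```

In the unpool function above, the analogous change would be to apply `ind % (output_shape[1] * output_shape[2] * output_shape[3])`, with the divisor cast to `ind.dtype`, before concatenating `ind_` with the batch coordinate `b`. Treat that as an untested suggestion.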