在TensorFlow中实现unpool时,tf.scatter_add出现了一个奇怪的错误

时间:2018-11-08 09:27:06

标签: python tensorflow

我正在尝试使用tf.scatter_add在tensorflow中实现unpool,但是遇到一个奇怪的错误,这是我的代码:

import tensorflow as tf
import numpy as np
import random

tf.reset_default_graph()

# Input: one 8x8 single-channel image whose pixels are a random permutation
# of 0..63, so every pooling window has a single, distinct maximum.
mat = list(range(64))
random.shuffle(mat)
mat = np.array(mat).reshape([1, 8, 8, 1])
M = tf.constant(mat, dtype=tf.float32)

# Three 2x2, stride-2 max pools (8x8 -> 4x4 -> 2x2 -> 1x1). The argmax outputs
# record where each maximum came from so the values can be unpooled later.
pool1, argmax1 = tf.nn.max_pool_with_argmax(M, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
pool2, argmax2 = tf.nn.max_pool_with_argmax(pool1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
pool3, argmax3 = tf.nn.max_pool_with_argmax(pool2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


def unpool(x, argmax, strides, unpool_shape=None, batch_size=None, name='unpool'):
    """Max-unpool `x` by scatter-adding it into a zero-initialized variable at `argmax`.

    The values of `x` are written into a flat, non-trainable variable of
    prod(unpool_shape) zeros at the flattened positions in `argmax`, and the
    result is reshaped to `unpool_shape`.

    NOTE: this version gives inflated values when it is evaluated more than
    once. `tf.scatter_add` adds `x` into the persistent variable every time the
    op runs, and fetching a later unpool layer re-runs the scatter ops of every
    earlier layer it depends on, so the sums keep growing from run to run.

    Args:
        x: pooled tensor (rank 4, since `range(4)` is used to derive the shape).
        argmax: indices from the matching pooling op; must have as many
            elements as `x`.
        strides: per-dimension factors used to derive `unpool_shape`.
        unpool_shape: output shape; defaults to x's static shape times strides.
        batch_size: required when x's static batch dimension is None.
        name: name of the variable created with `tf.get_variable`.

    Returns:
        The scattered values reshaped to `unpool_shape`.
    """
    # Static (graph-construction-time) shapes of the inputs.
    x_shape = x.get_shape().as_list()
    argmax_shape = argmax.get_shape().as_list()
    # A concrete batch size is needed to size the flat variable below.
    assert not(x_shape[0] is None and batch_size is None), "must input batch_size if number of batch is alterable"
    if x_shape[0] is None:
        x_shape[0] = batch_size
    if argmax_shape[0] is None:
        argmax_shape[0] = x_shape[0]
    if unpool_shape is None:
        # Upsample every dimension by its stride.
        unpool_shape = [x_shape[i] * strides[i] for i in range(4)]
    # Persistent flat buffer of zeros; it keeps its contents between sess.run calls.
    x_unpool = tf.get_variable(name=name, shape=[np.prod(unpool_shape)], initializer=tf.zeros_initializer(), trainable=False)
    argmax = tf.cast(argmax, tf.int32)
    # Flatten indices and values to 1-D so they line up element by element.
    argmax = tf.reshape(argmax, [np.prod(argmax_shape)])
    x = tf.reshape(x, [np.prod(argmax.get_shape().as_list())])
    # Adds (does not overwrite) into the variable, once per execution of this op.
    x_unpool = tf.scatter_add(x_unpool , argmax, x)
    x_unpool = tf.reshape(x_unpool , unpool_shape)
    return x_unpool 


# Undo the pools in reverse order: 1x1 -> 2x2 -> 4x4 -> 8x8.
unpool2 = unpool(pool3, argmax3, strides=[1, 2, 2, 1], name='unpool3')
unpool1 = unpool(unpool2, argmax2, strides=[1, 2, 2, 1], name='unpool2')
unpool0 = unpool(unpool1, argmax1, strides=[1, 2, 2, 1], name='unpool1')


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    mat_out = mat[:, :, :, 0]
    # Each tensor is fetched in its own sess.run call, in this exact order:
    # fetching an unpool tensor re-executes the scatter ops of every unpool
    # layer it depends on, so the printed values depend on the call sequence.
    pool1_out, pool2_out, pool3_out = (
        sess.run(t)[0, :, :, 0] for t in (pool1, pool2, pool3)
    )
    argmax1_out, argmax2_out, argmax3_out = (
        sess.run(t)[0, :, :, 0] for t in (argmax1, argmax2, argmax3)
    )
    unpool2_out, unpool1_out, unpool0_out = (
        sess.run(t)[0, :, :, 0] for t in (unpool2, unpool1, unpool0)
    )
    print(unpool2_out)
    print(unpool1_out)
    print(unpool0_out)

输出:

[[ 0.  0.]
 [ 0. 63.]]
[[  0.   0.   0.   0.]
 [  0.   0.   0.   0.]
 [  0.   0. 126.   0.]
 [  0.   0.   0.   0.]]
[[  0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0. 315.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.]]

位置正确,但值错误:unpool2是正确的,unpool1是期望值的两倍,unpool0是期望值的五倍。我不知道哪里出了问题,有人能告诉我如何解决这个错误吗?

非常感谢。

2 个答案:

答案 0 :(得分:1)

实际上,答案很简单。为了方便起见,我重命名了一些变量,请看下面的代码:

def unpool(x, argmax, strides, unpool_shape=None, batch_size=None, name='unpool'):
    """Same unpool as in the question, with the scatter and reshape results renamed.

    `x_unpool` is the persistent zero-initialized variable, `x_unpool_add` is
    the `tf.scatter_add` op that adds `x` into it at the `argmax` positions, and
    `x_unpool_reshape` is that result reshaped to `unpool_shape`. Every
    evaluation of `x_unpool_reshape` executes `x_unpool_add` again, so the
    variable's contents grow with each run.
    """
    # Static shapes; a concrete batch size is needed to size the variable.
    x_shape = x.get_shape().as_list()
    argmax_shape = argmax.get_shape().as_list()
    assert not(x_shape[0] is None and batch_size is None), "must input batch_size if number of batch is alterable"
    if x_shape[0] is None:
        x_shape[0] = batch_size
    if argmax_shape[0] is None:
        argmax_shape[0] = x_shape[0]
    if unpool_shape is None:
        # Upsample every dimension by its stride.
        unpool_shape = [x_shape[i] * strides[i] for i in range(4)]
    # Persistent flat buffer of zeros; keeps its contents between sess.run calls.
    x_unpool = tf.get_variable(name=name, shape=[np.prod(unpool_shape)], initializer=tf.zeros_initializer(), trainable=False)
    argmax = tf.cast(argmax, tf.int32)
    # Flatten indices and values to 1-D so they line up element by element.
    argmax = tf.reshape(argmax, [np.prod(argmax_shape)])
    x = tf.reshape(x, [np.prod(argmax.get_shape().as_list())])
    # The stateful op: each execution adds x into x_unpool once more.
    x_unpool_add = tf.scatter_add(x_unpool , argmax, x)
    # Evaluating this tensor triggers x_unpool_add.
    x_unpool_reshape = tf.reshape(x_unpool_add , unpool_shape)
    return x_unpool_reshape 

x_unpool_add是tf.scatter_add操作,每次计算x_unpool_reshape时都会执行一次x_unpool_add。因此,如果计算两次unpool2,x_unpool就会被加上两次x。在我的原始代码中,我按unpool2、unpool1、unpool0的顺序分别调用sess.run:计算unpool2时,unpool2的x_unpool_add执行一次;计算unpool1时,由于需要先算出unpool2,unpool2的x_unpool_add会再次执行,相当于被调用了两次,所以unpool1的值是期望值的两倍;计算unpool0时前面各层又被重复执行,累加后变成了五倍。如果只直接计算一次unpool0,就会得到正确的结果。因此,将tf.scatter_add替换为tf.scatter_update可以避免此错误:tf.scatter_update直接覆盖对应位置的值,重复执行不会累加。

此代码可以直观地重现:

import tensorflow as tf

t1 = tf.get_variable(name='t1', shape=[1], dtype=tf.float32, initializer=tf.zeros_initializer())
t2 = tf.get_variable(name='t2', shape=[1], dtype=tf.float32, initializer=tf.zeros_initializer())
# Every execution of `d` adds 1 to t1 and yields t1's updated value.
d = tf.scatter_add(t1, [0], [1])
# `e` consumes `d`, so running `e` executes `d` once more before adding.
e = tf.scatter_add(t2, [0], d)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Three separate runs: d, d again, then e (which re-runs d).
    d_out1, d_out2, e_out = (sess.run(op) for op in (d, d, e))
    print(d_out1)
    print(d_out2)
    print(e_out)

输出:

[1.]
[2.]
[3.]

答案 1 :(得分:0)

使用tf.scatter_update可以避免这种情况。

import tensorflow as tf
import numpy as np
import random

tf.reset_default_graph()

# 1x8x8x1 input holding 0..63 in random order (all values distinct).
mat = list(range(64))
random.shuffle(mat)
mat = np.array(mat)
mat = mat.reshape([1, 8, 8, 1])
M = tf.constant(mat, dtype=tf.float32)

# Pool three times with non-overlapping 2x2 windows, keeping each level's
# argmax indices for the unpooling below.
pool1, argmax1 = tf.nn.max_pool_with_argmax(
    M, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
pool2, argmax2 = tf.nn.max_pool_with_argmax(
    pool1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
pool3, argmax3 = tf.nn.max_pool_with_argmax(
    pool2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


def unpool(x, argmax, strides, unpool_shape=None, batch_size=None, name='unpool'):
    """Max-unpool `x` by placing each value at the position stored in `argmax`.

    Uses ``tf.scatter_nd``, the stateless counterpart of the
    ``tf.scatter_update`` approach: the output is built from a fresh zero
    tensor on every ``sess.run``. Writing into a persistent ``tf.Variable``
    instead means ``tf.scatter_add`` accumulates ``x`` on every run, and
    ``tf.scatter_update`` leaves values from earlier runs in place wherever a
    later run writes to different positions (e.g. when the input changes).
    No variable is created, so nothing here needs an initializer.

    Args:
        x: Pooled tensor (first output of ``tf.nn.max_pool_with_argmax``).
        argmax: Index tensor from the same pooling call; it must hold as many
            elements as ``x``.
        strides: Per-dimension upsampling factors (length 4); used only to
            derive ``unpool_shape`` when it is not given.
        unpool_shape: Static output shape. Defaults to ``x``'s static shape
            multiplied element-wise by ``strides``.
        batch_size: Static batch size; required when ``x``'s batch dimension
            is unknown (``None``).
        name: Name scope for the created ops.

    Returns:
        A tensor of shape ``unpool_shape`` that is zero everywhere except at
        the ``argmax`` positions, which hold the matching values of ``x``.
        Values that target the same index are summed.

    Raises:
        AssertionError: if both ``x``'s batch dimension and ``batch_size``
            are ``None``.
    """
    x_shape = x.get_shape().as_list()
    argmax_shape = argmax.get_shape().as_list()
    assert not (x_shape[0] is None and batch_size is None), "must input batch_size if number of batch is alterable"
    if x_shape[0] is None:
        x_shape[0] = batch_size
    if argmax_shape[0] is None:
        argmax_shape[0] = x_shape[0]
    if unpool_shape is None:
        unpool_shape = [x_shape[i] * strides[i] for i in range(4)]
    flat_size = int(np.prod(unpool_shape))
    num_values = int(np.prod(argmax_shape))
    with tf.name_scope(name):
        # NOTE(review): assumes each argmax entry indexes the flattened
        # unpool_shape (batch offset included); confirm against the
        # tf.nn.max_pool_with_argmax docs for the TF version in use.
        # scatter_nd expects one index per row, i.e. shape [num_values, 1].
        indices = tf.reshape(tf.cast(argmax, tf.int32), [num_values, 1])
        values = tf.reshape(x, [num_values])
        # The Python-int shape converts to int32, matching the int32 indices.
        flat = tf.scatter_nd(indices, values, [flat_size])
        return tf.reshape(flat, unpool_shape)


# Unpool level by level, mirroring the three pools in reverse.
unpool2 = unpool(pool3, argmax3, strides=[1, 2, 2, 1], name='unpool3')
unpool1 = unpool(unpool2, argmax2, strides=[1, 2, 2, 1], name='unpool2')
unpool0 = unpool(unpool1, argmax1, strides=[1, 2, 2, 1], name='unpool1')


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    mat_out = mat[:, :, :, 0]
    # One sess.run per tensor, in the original order; [0, :, :, 0] keeps the
    # single image and single channel as a 2-D array.
    pool1_out, pool2_out, pool3_out = (
        sess.run(t)[0, :, :, 0] for t in (pool1, pool2, pool3)
    )
    argmax1_out, argmax2_out, argmax3_out = (
        sess.run(t)[0, :, :, 0] for t in (argmax1, argmax2, argmax3)
    )
    unpool2_out, unpool1_out, unpool0_out = (
        sess.run(t)[0, :, :, 0] for t in (unpool2, unpool1, unpool0)
    )
    print(unpool2_out)
    print(unpool1_out)
    print(unpool0_out)

输出:

[[ 0.  0.]
 [ 0. 63.]]
[[ 0.  0.  0.  0.]
 [ 0.  0.  0.  0.]
 [ 0.  0.  0. 63.]
 [ 0.  0.  0.  0.]]
[[ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0. 63.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.]]