我正在尝试使用tf.scatter_add在tensorflow中实现unpool,但是遇到一个奇怪的错误,这是我的代码:
import tensorflow as tf
import numpy as np
import random
# Clear the default graph before building anything.
tf.reset_default_graph()

# Input: the integers 0..63 in random order, arranged as a [1, 8, 8, 1] array.
mat = list(range(64))
random.shuffle(mat)
mat = np.array(mat).reshape([1, 8, 8, 1])
M = tf.constant(mat, dtype=tf.float32)

# Three stacked 2x2, stride-2 max-pools. Each call also returns the argmax
# indices, which the unpooling code further down uses as scatter positions.
pool_kwargs = dict(ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
pool1, argmax1 = tf.nn.max_pool_with_argmax(M, **pool_kwargs)
pool2, argmax2 = tf.nn.max_pool_with_argmax(pool1, **pool_kwargs)
pool3, argmax3 = tf.nn.max_pool_with_argmax(pool2, **pool_kwargs)
def unpool(x, argmax, strides, unpool_shape=None, batch_size=None, name='unpool'):
    """Scatter-add the values of ``x`` into a zero-initialised flat variable.

    Args:
        x: Tensor whose values are scattered; its static shape is read with
            ``get_shape()`` and indexed along four axes.
        argmax: Tensor of positions, cast to int32 and used as indices into
            the flattened output.
        strides: Four per-axis factors, used only to derive the default
            output shape.
        unpool_shape: Output shape. Defaults to the shape of ``x`` multiplied
            element-wise by ``strides``.
        batch_size: Substitute for the leading dimension of ``x`` when that
            dimension is not statically known.
        name: Name handed to ``tf.get_variable`` for the backing variable.

    Returns:
        The result of ``tf.scatter_add`` on the backing variable, reshaped to
        ``unpool_shape``.

    Raises:
        AssertionError: If the leading dimension of ``x`` is unknown and no
            ``batch_size`` was given.

    Note:
        ``tf.scatter_add`` modifies the backing variable in place, so every
        evaluation of the returned tensor adds ``x`` on top of whatever the
        variable already holds.
    """
    # NOTE(review): assumes the argmax indices address the flattened output
    # of the whole batch - confirm for batch sizes greater than 1.
    in_shape = x.get_shape().as_list()
    idx_shape = argmax.get_shape().as_list()
    assert in_shape[0] is not None or batch_size is not None, "must input batch_size if number of batch is alterable"

    # Fill in a missing batch dimension so every size below is a plain int.
    if in_shape[0] is None:
        in_shape[0] = batch_size
    if idx_shape[0] is None:
        idx_shape[0] = in_shape[0]
    if unpool_shape is None:
        unpool_shape = [in_shape[axis] * strides[axis] for axis in range(4)]

    flat_size = np.prod(unpool_shape)
    buffer = tf.get_variable(
        name=name,
        shape=[flat_size],
        initializer=tf.zeros_initializer(),
        trainable=False,
    )

    # Flatten indices and values so they pair up one-to-one.
    flat_idx = tf.reshape(tf.cast(argmax, tf.int32), [np.prod(idx_shape)])
    flat_x = tf.reshape(x, [np.prod(flat_idx.get_shape().as_list())])

    accumulated = tf.scatter_add(buffer, flat_idx, flat_x)
    return tf.reshape(accumulated, unpool_shape)
# Undo the three pooling stages in reverse order; each stage feeds the next.
unpool_strides = [1, 2, 2, 1]
unpool2 = unpool(pool3, argmax3, strides=unpool_strides, name='unpool3')
unpool1 = unpool(unpool2, argmax2, strides=unpool_strides, name='unpool2')
unpool0 = unpool(unpool1, argmax1, strides=unpool_strides, name='unpool1')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Keep only the 2-D plane of the single batch item / single channel.
    mat_out = mat[:, :, :, 0]
    pool1_out, pool2_out, pool3_out = [
        sess.run(tensor)[0, :, :, 0] for tensor in (pool1, pool2, pool3)
    ]
    argmax1_out, argmax2_out, argmax3_out = [
        sess.run(tensor)[0, :, :, 0] for tensor in (argmax1, argmax2, argmax3)
    ]

    # Every unpool tensor is fetched by its own sess.run call, in this
    # order. Fetching a later stage re-executes the scatter ops of the
    # stages it depends on.
    unpool2_out, unpool1_out, unpool0_out = [
        sess.run(tensor)[0, :, :, 0] for tensor in (unpool2, unpool1, unpool0)
    ]

    for result in (unpool2_out, unpool1_out, unpool0_out):
        print(result)
输出:
[[ 0. 0.]
[ 0. 63.]]
[[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 126. 0.]
[ 0. 0. 0. 0.]]
[[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 315. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]]
位置正确,但值错误。unpool2是正确的,unpool1是期望值的两倍,unpool0是期望值的五倍。我不知道哪里出了问题,有人可以告诉我如何解决此错误吗?
非常感谢。
答案 0 :(得分:1)
实际上,答案很简单。为了方便起见,我重命名了一些变量,请看下面的代码:
def unpool(x, argmax, strides, unpool_shape=None, batch_size=None, name='unpool'):
    """Unpool ``x`` by scatter-adding it into a flat, zero-initialised variable.

    Args:
        x: Tensor to scatter; its static shape is indexed along four axes.
        argmax: Positions for the values of ``x``; cast to int32 and used as
            indices into the flattened output.
        strides: Four per-axis factors for deriving the default output shape.
        unpool_shape: Explicit output shape; when omitted it is the shape of
            ``x`` times ``strides``, axis by axis.
        batch_size: Used for the leading dimension when ``x`` does not have a
            statically known one.
        name: Name of the variable created through ``tf.get_variable``.

    Returns:
        ``x_unpool_reshape``: the output of the scatter-add op reshaped to
        ``unpool_shape``.

    Raises:
        AssertionError: If neither the static batch dimension of ``x`` nor
            ``batch_size`` is available.
    """
    static_x_shape = x.get_shape().as_list()
    static_argmax_shape = argmax.get_shape().as_list()
    batch_unknown = static_x_shape[0] is None
    assert not (batch_unknown and batch_size is None), "must input batch_size if number of batch is alterable"

    if batch_unknown:
        static_x_shape[0] = batch_size
    if static_argmax_shape[0] is None:
        static_argmax_shape[0] = static_x_shape[0]
    if unpool_shape is None:
        unpool_shape = [static_x_shape[i] * strides[i] for i in range(4)]

    x_unpool = tf.get_variable(
        name=name,
        shape=[np.prod(unpool_shape)],
        initializer=tf.zeros_initializer(),
        trainable=False,
    )

    argmax = tf.cast(argmax, tf.int32)
    argmax = tf.reshape(argmax, [np.prod(static_argmax_shape)])
    n_values = np.prod(argmax.get_shape().as_list())
    x = tf.reshape(x, [n_values])

    # x_unpool_add is a stateful op: each time it executes, x is added onto
    # the current contents of x_unpool.
    x_unpool_add = tf.scatter_add(x_unpool, argmax, x)
    x_unpool_reshape = tf.reshape(x_unpool_add, unpool_shape)
    return x_unpool_reshape
x_unpool_add是tf.scatter_add操作,每次计算x_unpool_reshape时,都会执行一次x_unpool_add,并就地修改变量x_unpool。因此,如果对同一个unpool张量求值两次,x就会被累加到x_unpool上两次。在我的原始代码中,我依次运行unpool2、unpool1、unpool0:运行unpool2时,它的x_unpool_add执行了一次;随后运行unpool1时,由于unpool1依赖unpool2,unpool2的x_unpool_add会再次执行,相当于累加了两次,因此得到的值是错误的;运行unpool0时同理,上游的累加又各执行一次,误差进一步放大。如果在新的会话中只直接运行一次unpool0,就会得到正确的结果。因此,将tf.scatter_add替换为tf.scatter_update可以避免此错误。
此代码可以直观地重现:
import tensorflow as tf
# Two one-element float32 variables, both starting at zero.
var_kwargs = dict(shape=[1], dtype=tf.float32, initializer=tf.zeros_initializer())
t1 = tf.get_variable(name='t1', **var_kwargs)
t2 = tf.get_variable(name='t2', **var_kwargs)

# d adds 1 to t1 in place. e adds the output of d to t2, so evaluating e
# executes d once more.
d = tf.scatter_add(t1, [0], [1])
e = tf.scatter_add(t2, [0], d)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Three separate runs, in this order: d, d again, then e.
    d_out1, d_out2, e_out = [sess.run(op) for op in (d, d, e)]

for value in (d_out1, d_out2, e_out):
    print(value)
输出:
[1.]
[2.]
[3.]
答案 1 :(得分:0)
使用tf.scatter_update可以避免这种情况。
import tensorflow as tf
import numpy as np
import random
# Begin with a cleared default graph.
tf.reset_default_graph()

# A random permutation of 0..63 reshaped to [1, 8, 8, 1].
mat = list(range(64))
random.shuffle(mat)
mat = np.reshape(np.array(mat), [1, 8, 8, 1])
M = tf.constant(mat, dtype=tf.float32)

# The same 2x2 window serves as both kernel size and stride for all three
# pooling stages; the argmax outputs are kept for unpooling.
window = [1, 2, 2, 1]
pool1, argmax1 = tf.nn.max_pool_with_argmax(M, ksize=window, strides=window, padding='SAME')
pool2, argmax2 = tf.nn.max_pool_with_argmax(pool1, ksize=window, strides=window, padding='SAME')
pool3, argmax3 = tf.nn.max_pool_with_argmax(pool2, ksize=window, strides=window, padding='SAME')
def unpool(x, argmax, strides, unpool_shape=None, batch_size=None, name='unpool'):
    """Unpool ``x`` by writing its values at the ``argmax`` positions.

    The values are written with ``tf.scatter_update`` into a flat variable
    that is reset to zero on every evaluation, so repeated or chained
    ``session.run`` calls always yield the same result for the same input.

    Args:
        x: Tensor to unpool; its static shape is indexed along four axes.
        argmax: Positions of the values of ``x`` in the flattened output;
            cast to int32 before use.
        strides: Four per-axis factors, used only to derive the default
            output shape.
        unpool_shape: Output shape. Defaults to the shape of ``x`` multiplied
            element-wise by ``strides``.
        batch_size: Substitute for the leading dimension of ``x`` when that
            dimension is not statically known.
        name: Name handed to ``tf.get_variable`` for the backing variable.

    Returns:
        A tensor of shape ``unpool_shape`` holding the values of ``x`` at the
        ``argmax`` positions and zeros everywhere else.

    Raises:
        AssertionError: If the leading dimension of ``x`` is unknown and no
            ``batch_size`` was given.
    """
    # NOTE(review): assumes the argmax indices address the flattened output
    # of the whole batch - confirm for batch sizes greater than 1.
    x_shape = x.get_shape().as_list()
    argmax_shape = argmax.get_shape().as_list()
    assert not (x_shape[0] is None and batch_size is None), "must input batch_size if number of batch is alterable"
    if x_shape[0] is None:
        x_shape[0] = batch_size
    if argmax_shape[0] is None:
        argmax_shape[0] = x_shape[0]
    if unpool_shape is None:
        unpool_shape = [x_shape[i] * strides[i] for i in range(4)]

    flat_size = int(np.prod(unpool_shape))
    n_indices = int(np.prod(argmax_shape))

    # Local is deliberately not called ``unpool`` so it does not shadow this
    # function's own name.
    buffer = tf.get_variable(
        name=name,
        shape=[flat_size],
        initializer=tf.zeros_initializer(),
        trainable=False,
    )

    flat_argmax = tf.reshape(tf.cast(argmax, tf.int32), [n_indices])
    flat_x = tf.reshape(x, [n_indices])

    # The variable persists between session.run calls. Without this reset,
    # positions written by an earlier evaluation with different argmax
    # indices would survive as stale non-zero entries.
    clear = tf.assign(buffer, tf.zeros([flat_size], dtype=buffer.dtype.base_dtype))
    with tf.control_dependencies([clear]):
        scattered = tf.scatter_update(buffer, flat_argmax, flat_x)

    return tf.reshape(scattered, unpool_shape)
# Chain the unpooling stages in the reverse order of the pooling stages.
stage_strides = [1, 2, 2, 1]
unpool2 = unpool(pool3, argmax3, strides=stage_strides, name='unpool3')
unpool1 = unpool(unpool2, argmax2, strides=stage_strides, name='unpool2')
unpool0 = unpool(unpool1, argmax1, strides=stage_strides, name='unpool1')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Reduce everything to the 2-D plane of batch item 0 / channel 0.
    mat_out = mat[:, :, :, 0]
    pool1_out, pool2_out, pool3_out = [
        sess.run(tensor)[0, :, :, 0] for tensor in (pool1, pool2, pool3)
    ]
    argmax1_out, argmax2_out, argmax3_out = [
        sess.run(tensor)[0, :, :, 0] for tensor in (argmax1, argmax2, argmax3)
    ]

    # One sess.run per tensor, innermost stage first.
    unpool2_out, unpool1_out, unpool0_out = [
        sess.run(tensor)[0, :, :, 0] for tensor in (unpool2, unpool1, unpool0)
    ]

    for result in (unpool2_out, unpool1_out, unpool0_out):
        print(result)
输出:
[[ 0. 0.]
[ 0. 63.]]
[[ 0. 0. 0. 0.]
[ 0. 0. 0. 0.]
[ 0. 0. 0. 63.]
[ 0. 0. 0. 0.]]
[[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 63.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.]]