这是tf.scatter_nd_update的正确渐变实现吗?

时间:2018-11-11 04:22:30

标签: python python-3.x tensorflow

不幸的是,Tensorflow不为tf.scatter_nd_update提供任何梯度支持,并且在向后传递中,梯度在那里停止。本质上,此函数只是跨多个数组的一系列分配操作,因此,在每个分配操作中,右侧的梯度应仅传播到左侧。

我已经为tf.scatter_nd_update实现了自己的渐变,但是我不确定它是否正确,因为我不得不将更新和索引的渐变设置为零,因为我无法选择它们。这是我的实现:

import tensorflow as tf
import numpy as np

def reset_graph(seed=4):
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)

@tf.custom_gradient
def scatter_nd_w_gradient(phi,indices,update):
    phi = tf.scatter_nd_update(phi,indices,update) 

    def grad(dy):
        dz= tf.zeros([2,4], dtype='float32')
        dt= tf.zeros([2], dtype='float32')
        return [dy,dz,dt]

    return phi, grad

def some_operation(x):

    phi = tf.Variable(tf.zeros([1,10,10,1], dtype='float32'), dtype='float32', trainable=True)
    phi_prime= tf.zeros([1,10,10,1], dtype='float32')
    phi=  tf.assign(phi,tf.cast(phi_prime, dtype='float32'))

    ind_y=tf.constant([0,1,3,0])
    ind_x=tf.constant([0,2,1,0])

    indices=ind_y,ind_x
    update=tf.stack([x[0,4,4,0],x[0,4,3,0]])

    phi = scatter_nd_w_gradient(phi,indices,update)
    c3=tf.nn.sigmoid(phi)
    c4=tf.reduce_mean(c3)

    return 1-c4

reset_graph()
a = np.ones((10,10), dtype=np.float32)
k = np.array([[1,1,1],[1,1,1],[1,1,1]],dtype=np.float32)
flip = [slice(None, None, -1), slice(None, None, -1)]
k = k[flip]

a=a.astype(np.float32)
a_tensor = tf.reshape(a, [1, 10, 10, 1])
k_weight = tf.reshape(np.array(k), [3,3,1,1])

c2=tf.layers.conv2d(a_tensor,filters=1, kernel_size=3, strides=1, padding="same",activation = tf.nn.relu)

total_loss2=some_operation(c2)
train_op = tf.train.AdamOptimizer(1e-3).minimize(total_loss2,colocate_gradients_with_ops=True)

init = tf.initialize_all_variables()
sess=tf.Session()
with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)
    _,c2=sess.run([train_op,c2])
    print('this is the value for c2 {}'.format(c2))

代码听起来可能很复杂,但事实并非如此。我只是做一个简单的卷积,然后执行一些操作(可能并不真正有意义,但是为了展示这个概念),并为变量分配一个值,并使用此tf.scatter_nd_update功能从卷积层的输出进行更新。

如果我的实现是正确的,它将帮助很多人尝试使用神经网络的输出并将其与另一个单元组合。请让我知道这是否对您有意义。

0 个答案:

没有答案