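Computing gradients for each instance in a minibatch

Time: 2017-10-23 11:48:19

Tags: python python-3.x tensorflow

I want to compute the gradients for each instance in a minibatch. The end goal is to weight the gradients differently depending on input_i in the minibatch. However, map_fn currently gives the following error: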

ValueError: The two structures don't have the same number of elements.

First structure (1 elements): <dtype: 'float32'>

Second structure (2 elements): [<tf.Tensor 'map/while/gradients/MatMul_grad/MatMul_1:0' shape=(4, 1) dtype=float32>, <tf.Tensor 'map/while/gradients/add_grad/Reshape_1:0' shape=(1, 1) dtype=float32>]

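I think the error occurs because the value returned by compute_grad_i does not have the same shape as its loss_i input. Here is code that reproduces the error: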

import tensorflow as tf
import numpy as np

x = tf.placeholder(dtype=tf.float32, shape=[None, 4])
y = tf.placeholder(dtype=tf.float32, shape=[None, 1])

W = tf.get_variable('w', shape=[4, 1])
b = tf.get_variable('b', shape=[1, 1])

y_pred = tf.matmul(x, W) + b

loss = (y_pred - y) ** 2

trainable_vars = tf.trainable_variables()
compute_grad_i = lambda loss_i: tf.gradients(loss_i, trainable_vars)
grads = tf.map_fn(compute_grad_i, loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
feed_dict = {x: np.random.randn(3, 4), y: np.random.randn(3, 1)}
g = sess.run([grads], feed_dict=feed_dict)

1 answer:

Answer 0 (score: 0)

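According to this issue, computing gradients inside map_fn is a known problem with TensorFlow.

Here is an ugly way to weight the gradients of each example in the batch: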

import tensorflow as tf
import numpy as np

x = tf.placeholder(dtype=tf.float32, shape=[None, 4])
y = tf.placeholder(dtype=tf.float32, shape=[None, 1])

W = tf.get_variable('w', shape=[4, 1])
b = tf.get_variable('b', shape=[1, 1])

y_pred = tf.matmul(x, W) + b

loss = (y_pred - y) ** 2

trainable_vars = tf.trainable_variables()

g = tf.gradients(loss, trainable_vars)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

x_in = np.random.randn(3, 4)
y_in = np.random.randn(3, 1)

gs = None
mini_i_w = [1,2,3]
for i, [x_i, y_i] in enumerate(zip(x_in, y_in)):
    feed_dict = {x: np.reshape(x_i, [-1, x_in.shape[1]]), y: np.reshape(y_i, [-1, y_in.shape[1]])}
    g_i = sess.run(g, feed_dict=feed_dict)

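    # g_i is a list of arrays (one per variable), so scale each entry;
    # `mini_i_w[i] * g_i` would repeat the list instead of scaling it.
    if gs is None:
        gs = [mini_i_w[i] * g_k for g_k in g_i]
    else:
        gs = [gs[k] + mini_i_w[i] * g_i[k] for k in range(len(g_i))]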
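
If the weights do not depend on the gradients themselves, the weighted sum of per-example gradients should equal the gradient of the weighted loss sum_i w_i * loss_i, so the Python loop may not be needed. The NumPy sketch below checks that identity for the linear model above. The per-example gradients are derived by hand, and the seed and the names `weighted_loss`, `grad_W_each`, `grad_b_each`, `num_W` and `num_b` are assumptions made for illustration, not part of the original code.

```python
import numpy as np

rng = np.random.RandomState(0)
x_in = rng.randn(3, 4)                  # 3 examples, 4 features
y_in = rng.randn(3, 1)
W = rng.randn(4, 1)
b = rng.randn(1, 1)
mini_i_w = np.array([1.0, 2.0, 3.0])    # per-example weights


def weighted_loss(W, b):
    """sum_i w_i * (x_i @ W + b - y_i) ** 2, returned as a Python float."""
    resid = x_in @ W + b - y_in
    return float(np.sum(mini_i_w[:, None] * resid ** 2))


# Route 1: per-example gradients (derived by hand), then a weighted sum.
resid = x_in @ W + b - y_in                                 # shape (3, 1)
grad_W_each = 2.0 * resid[:, :, None] * x_in[:, :, None]    # shape (3, 4, 1)
grad_b_each = 2.0 * resid[:, :, None]                       # shape (3, 1, 1)
gs_W = np.tensordot(mini_i_w, grad_W_each, axes=1)          # shape (4, 1)
gs_b = np.tensordot(mini_i_w, grad_b_each, axes=1)          # shape (1, 1)

# Route 2: central finite differences of the weighted loss itself.
eps = 1e-5
num_W = np.zeros_like(W)
for k in range(W.shape[0]):
    step = np.zeros_like(W)
    step[k, 0] = eps
    num_W[k, 0] = (weighted_loss(W + step, b) - weighted_loss(W - step, b)) / (2 * eps)
num_b = (weighted_loss(W, b + eps) - weighted_loss(W, b - eps)) / (2 * eps)

# Both routes should agree up to floating-point error.
print(np.allclose(gs_W, num_W, atol=1e-6), np.isclose(gs_b[0, 0], num_b, atol=1e-6))
```

In TensorFlow 1.x the same idea would presumably be a single call such as tf.gradients(loss, trainable_vars, grad_ys=weights) with a weights tensor shaped like loss, or differentiating tf.reduce_sum(weights * loss). This is untested against the graph above, so treat it as a suggestion rather than a verified fix.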