如何在Tensorflow的小批量生产中进行选择性反向传播?

时间:2018-12-11 23:26:27

标签: tensorflow lstm backpropagation mini-batch

最近,我正在做一个项目“ 通过使用Tensorflow中的LSTM从对象的过去轨迹预测对象的未来轨迹。” (在这里,轨迹表示2D位置的序列。)

LSTM的输入当然是“过去的轨迹”,输出是“未来的轨迹”。

训练时小批量的大小是固定的。但是,小批量中的过去轨迹的数量可以不同。例如,让小批量的大小为10。如果我当前的训练迭代只有4条过去的轨迹,则小批10中的6分将填充零值。

在计算反向传播的损失时,我将6个点的损失设为零,以便仅有4个对反向传播做出贡献。

我关心的问题是..即使Tensorflow的损失为零,似乎Tensorflow仍会为6计算梯度。结果,即使我使用相同的训练数据,随着我增加小批量的大小,训练速度也会变慢。

在计算损失时,我还使用了tf.where函数。但是,训练时间不会减少。

如何减少培训时间?

我在这里附上我的伪代码进行培训。

# For each frame in a sequence
for f in range(pred_length):

    # For each element in a batch
    for b in range(batch_size):


        with tf.variable_scope("rnnlm") as scope:
            if (f > 0 or b > 0):
                scope.reuse_variables()

            # for each pedestrian in an element
            for p in range(MNP):

                # ground-truth position
                cur_gt_pose = ...

                # loss mask
                loss_mask_ped = ... # '1' or '0'

                # go through RNN decoder
                output_states_dec_list[b][p], zero_states_dec_list[b][p] = cell_dec(cur_embed_frm_dec,
                                                                                    zero_states_dec_list[b][p])

                # fully connected layer for output
                cur_pred_pose_dec = tf.nn.xw_plus_b(output_states_dec_list[b][p], output_wd, output_bd)

                # go through embedding function for the next input
                prev_embed_frms_dec_list[b][p] = tf.reshape(tf.nn.relu(tf.nn.xw_plus_b(cur_pred_pose_dec, embedding_wd, embedding_bd)), shape=(1, rnn_size))

                # calculate MSE loss
                mse_loss = tf.reduce_sum(tf.pow(tf.subtract(cur_pred_pose_dec, cur_gt_pose_dec), 2.0))

                # only valid ped's traj contributes to the loss
                self.loss += tf.multiply(mse_loss, loss_mask_ped)

1 个答案:

答案 0 :(得分:0)

我认为您正在寻找函数tf.stop_gradient。使用此方法,假设尺寸正确,您可以执行类似tf.where(loss_mask, tensor, tf.stop_gradient(tensor))的操作以获得所需的结果。

但是,看起来这可能不是您的问题。似乎对于数据集中的每个项目,您都在定义新的图节点。这不是TensorFlow应该发挥作用的方式,无论批次大小如何,您都只应预先构建一个执行某些固定功能的图形。绝对不应该为批处理中的每个元素定义新节点,因为这样不能有效利用并行性。