如何通过索引切片Tensorflow张量?

时间:2018-08-09 18:34:34

标签: python tensorflow slice

我发现了一个自定义损失函数loss_mse_warmup以用于我的张量流模型:model.compile(optimizer = RMSprop(lr=1e-3), loss = loss_mse_warmup)

功能如下:

def loss_mse_warmup(y_true, y_pred):

    y_true_slice = y_true[:, warmup_steps:,]
    y_pred_slice = y_pred[:, warmup_steps:,]
    loss = tf.losses.mean_squared_error(labels=y_true_slice,predictions=y_pred_slice)
    loss_mean = tf.reduce_mean(loss)

    return loss_mean

但我想通过张量按索引indices=[0,3]进行切片。如果y_true_slice, y_pred_slice是numpy数组,我想这样做:

y_true_slice = y_true[:, warmup_steps:,[0,3]]
y_pred_slice = y_pred[:, warmup_steps:,[0,3]]

我该如何实现?我应该如何更改loss_mse_warmup函数?张量的其他信息:

y_pred.shape = (?, 4)
y_pred.get_shape = <bound method Tensor.get_shape of <tf.Tensor 'dense/BiasAdd:0' shape=(?, 4) dtype=float32>>
tf.shape(y_pred) = Tensor("loss_9/dense_1_loss/Shape:0", shape=(2,), dtype=int32)

0 个答案:

没有答案