如何按恒定值缩放Tensor列?

时间:2019-05-26 19:55:49

标签: tensorflow keras

我需要按一个恒定值缩放张量的某些列,但是我不知道如何使用Keras / Tensorflow来解决这个问题。我有一个(BatchSize,6)矩阵,我需要将第2列乘以一个常数,而将5列乘以另一个常数。

我尝试制作一个Lambda,该Lambda使用切片索引将列相乘,但TF返回有关无法为结果分配值的错误。

例如

x[:,2] *= constant

有什么建议吗?

1 个答案:

答案 0 :(得分:1)

只需在所需的列位置乘以一个张量和常数即可。例如:

import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, (None, 6))

const1 = 5.
const2 = 3.
scaler = tf.constant([1, 1, const1, 1, 1, const2], dtype=tf.float32)
res = x*scaler

x_data = np.ones((3, 6))
with tf.Session() as sess:
    print(res.eval({x:x_data}))
# [[1. 1. 5. 1. 1. 3.]
#  [1. 1. 5. 1. 1. 3.]
#  [1. 1. 5. 1. 1. 3.]]