如何在张量流中的切片张量上使用tf.clip_by_value()?

时间:2017-10-03 13:39:24

标签: python tensorflow deep-learning

我使用RNN根据过去24小时的湿度和温度值预测下一小时的湿度和温度。为了训练模型,我的输入和输出张量的形状为[24,2],如下所示:

[[23, 78],
 [24, 79],
 [25, 78],
 [23, 81],
  .......
 [27, 82],
 [21, 87],
 [28, 88],
 [23, 90]]

这里我想将湿度列(秒)的值剪切在0到100之间,因为它不能超越它。

我为此目的使用的代码是

.....
outputs[:,1] = tf.clip_by_value(outputs[:,1], 0, 100)
.....

收到以下错误:

'Tensor' object does not support item assignment

将tf.clip_by_value()仅用于一列的正确方法是什么?

3 个答案:

答案 0 :(得分:4)

我认为最简单(但可能不是最优)的方法是使用outputs沿第二维分割tf.split,然后应用裁剪并连接回来(如果需要)。

temperature, humidity = tf.split(output, 2, axis=1)
humidity = tf.clip_by_value(humidity, 0, 100)

# optional concat
clipped_output = tf.concat([temperature, humidity], axis=1)

答案 1 :(得分:1)

如果您的outputs是变量,则可以使用tf.assign

tf.assign(outputs[:,1], tf.clip_by_value(outputs[:,1], 0, 100))
import tensorflow as tf
a = tf.Variable([[23, 78],
 [24, 79],
 [25, 78],
 [23, 81],
 [27, 82],
 [21, 87],
 [28, 88],
 [23, 90]])

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    clipped_value = tf.clip_by_value(a[:,1], 80, 85)
    sess.run(tf.assign(a[:,1], clipped_value))
    print(sess.run(a))

#[[23 80]
# [24 80]
# [25 80]
# [23 81]
# [27 82]
# [21 85]
# [28 85]
# [23 85]]

答案 2 :(得分:0)

以下内容未在手册页https://www.tensorflow.org/api_docs/python/tf/clip_by_value中记录,但在我的测试中似乎有效:clip_by_value应该支持广播。如果是这样,执行此裁剪的最简单方法(如:不创建临时张量)如下:

COALESCE((SELECT TOP (1)
cb.Name
FROM CustomerBilling cb WITH(NOLOCK)
JOIN BillType bt WITH(NOLOCK)
    ON bt.DiagTypeDimID = cb.DiagTypeDimID                                   
WHERE cb.CustomerEventID = d.CustomerEventID --linking this subquery to the main query
 AND bt.DiagType = 'FullBillName'
 ORDER BY cb.ConfirmedDtm DESC), 
(SELECT TOP (1)
cb.Name
FROM CustomerBilling cb WITH(NOLOCK)
JOIN BillType bt WITH(NOLOCK)
    ON bt.DiagTypeDimID = cb.DiagTypeDimID                                   
WHERE cb.CustomerEventID = d.CustomerEventID --linking this subquery to the main query
 AND bt.DiagType = 'QuickBill1'
 ORDER BY cb.ConfirmedDtm DESC), 
(SELECT TOP (1)
cb.Name
FROM CustomerBilling cb WITH(NOLOCK)
JOIN BillType bt WITH(NOLOCK)
    ON bt.DiagTypeDimID = cb.DiagTypeDimID                                   
WHERE cb.CustomerEventID = d.CustomerEventID --linking this subquery to the main query
 AND bt.DiagType = 'QuickBill2'
 ORDER BY cb.ConfirmedDtm DESC), '**Bill Unavilable**')  AS 
"BillName"

这里,我假设您使用的是outputs = tf.clip_by_value(outputs, [[-2147483647, 0]], [[2147483647, 100]]) dtype,因此不想剪切的字段的最小值和最大值。诚然,它并不是超级好,对于浮点数来说看起来更好,在浮点数中可以使用tf.int32-numpy.inf来代替。