忽略张量流扫描中的小变化

时间:2018-04-18 01:23:12

标签: tensorflow

我想比较扫描中的当前值和先前值,并修改当前值,以便忽略相对于之前小于阈值的更改。另外,如何计算更改大于阈值的数字元素?

这是我的尝试

def f(prev_y, curr_y):
    thrshd= 0.05
    diff = curr_y - prev_y
    pos_dif = tf.nn.relu(diff-thrshd)
    neg_dif = tf.nn.relu(-diff+thrshd)
    fval = prev_y + pos_dif - neg_dif
    return fval

a = tf.constant([[.1, .26, .3, .2, .15],
                 [.07, .35, .24, .23, .19]])

init = tf.constant([0.2, 0.2, 0.2, 0.2, 0.2])
c = tf.scan(f, a, initializer=init)

with tf.Session() as sess:
  print(sess.run(c))

输出:

[[0.05       0.21       0.25       0.15       0.10000001]
 [0.02       0.3        0.19       0.18       0.14      ]]

期望的输出:

[[0.1        0.26       0.30       0.20       0.2]
 [0.1        0.35       0.24       0.20       0.2  ]]

超过阈值的变化次数:3,2

1 个答案:

答案 0 :(得分:0)

我使用tf.where

import tensorflow as tf

def f(prev_y, curr_y):
    thrshd= 0.05
    return tf.where(tf.greater(tf.abs(curr_y - prev_y), thrshd),
                    curr_y,
                    prev_y)

a = tf.constant([[.1, .26, .3, .2, .15],
                 [.07, .35, .24, .23, .19]])

init = tf.constant([0.2, 0.2, 0.2, 0.2, 0.2])
c = tf.scan(f, a, initializer=init)

with tf.Session() as sess:
  print(sess.run(c)

打印:

[[0.1  0.26 0.3  0.2  0.2 ]
 [0.1  0.35 0.24 0.2  0.2 ]]

还要计算超过阈值的值:

import tensorflow as tf

def f(accumulator, curr_y):
    prev_y, _ = accumulator
    thrshd= 0.05
    greater = tf.greater(tf.abs(curr_y - prev_y), thrshd)
    new_array = tf.where(greater, curr_y, prev_y)
    return (new_array, tf.count_nonzero(greater))

a = tf.constant([[.1, .26, .3, .2, .15],
                 [.07, .35, .24, .23, .19]])

init = tf.constant([0.2, 0.2, 0.2, 0.2, 0.2])
output_array, output_count  = tf.scan(
    f, a, initializer=(init, tf.zeros([], dtype=tf.int64)))
with tf.Session() as sess:
  print(sess.run((output_array, output_count)))

打印:

(array([[0.1 , 0.26, 0.3 , 0.2 , 0.2 ],
       [0.1 , 0.35, 0.24, 0.2 , 0.2 ]], dtype=float32), array([3, 2]))

含义3个值在第一次迭代时高于阈值,2为第二次迭代。