tf.scan中的索引

时间:2018-10-20 12:38:24

标签: numpy tensorflow vectorization scanning map-function

我需要对下面给出的非常大的张量进行操作np_data。使用静态numpy数组np_functions相应地更改了该张量的每个元素。我需要一种方法来确定np_functions的索引,具体取决于张量run_fn处于打开状态。

import tensorflow as tf
import numpy as np

np_data = np.array([[1,0], 
                    [1,0],
                    [0,1]])

np_functions = np.array([[1,0,1],
                         [2,1,0],
                         [3,0,1]])

def run_fn(_a, _, _data_in):
    INDEX = run_fn_current_index_in_tensor_data_in?!?!?!?
    if np_functions[INDEX][0] == 1:
        return _data_in[0] + _data_in[1]
    if np_functions[INDEX][0] == 2:
        return _data_in[0] - _data_in[1]
    if np_functions[INDEX][0] == 3:
        return _data_in[0] / _data_in[2]


_data_in = tf.placeholder(tf.int32, shape=(3, 2), name='data_in')
data = tf.scan(lambda _a, _: run_fn(_a, _, _data_in), _data_in)
sess = tf.Session()
model =  tf.global_variables_initializer()
sess.run(model)
np_data = sess.run(data, feed_dict={_data_in: np_data})

在np_data中放置索引或将np_functions馈入tf.scan并不是一个好的解决方案,因为np_function变为动态张量,并且每次迭代都会对其进行评估。 np_functions是静态的,并且在每个run_fn单元中,所有迭代都会做出相同的决定。 np_data正在更改每次迭代。

0 个答案:

没有答案