我想遍历一个包含Int
列表的张量,并将一个函数应用于每个元素。
在函数中,每个元素都将从python的字典中获取值。
我尝试过使用tf.map_fn
的简单方法,该方法可以在add
函数上工作,例如以下代码:
import tensorflow as tf
def trans_1(x):
return x+10
a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
res = sess.run(b)
print(str(res))
# output: [11 12 13]
但是以下代码引发了KeyError: tf.Tensor'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32
异常:
import tensorflow as tf
kv_dict = {1:11, 2:12, 3:13}
def trans_2(x):
return kv_dict[x]
a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
res = sess.run(b)
print(str(res))
我的张量流版本为1.13.1
。谢谢你。
答案 0 :(得分:0)
您不能使用这样的函数,因为参数x
是TensorFlow张量,而不是Python值。因此,为了使它起作用,您还必须将字典转换为张量,但这并不是那么简单,因为字典中的键可能不是顺序的。
您可以代替映射来解决此问题,而对NumPy进行类似于what is proposed here的操作。在TensorFlow中,您可以像这样实现它:
import tensorflow as tf
def replace_by_dict(x, d):
# Get keys and values from dictionary
keys, values = zip(*d.items())
keys = tf.constant(keys, x.dtype)
values = tf.constant(values, x.dtype)
# Make a sequence for the range of values in the input
v_min = tf.reduce_min(x)
v_max = tf.reduce_max(x)
r = tf.range(v_min, v_max + 1)
r_shape = tf.shape(r)
# Mask replacements that are out of the input range
mask = (keys >= v_min) & (keys <= v_max)
keys = tf.boolean_mask(keys, mask)
values = tf.boolean_mask(values, mask)
# Replace values in the sequence with the corresponding replacements
scatter_idx = tf.expand_dims(keys, 1) - v_min
replace_mask = tf.scatter_nd(
scatter_idx, tf.ones_like(values, dtype=tf.bool), r_shape)
replace_values = tf.scatter_nd(scatter_idx, values, r_shape)
replacer = tf.where(replace_mask, replace_values, r)
# Gather the replacement value or the same value if it was not modified
return tf.gather(replacer, x - v_min)
# Test
kv_dict = {1: 11, 2: 12, 3: 13}
with tf.Graph().as_default(), tf.Session() as sess:
a = tf.constant([1, 2, 3])
print(sess.run(replace_by_dict(a, kv_dict)))
# [11, 12, 13]
这将允许您在输入张量中具有值而无需替换(保持原样),也不需要在张量中具有所有替换值。除非输入中的最大值和最小值很远,否则它应该是有效的。
答案 1 :(得分:0)
有一个简单的方法可以实现您正在尝试的目标。
问题在于传递给map_fn
的函数必须将张量作为其参数并将张量作为返回值。但是,您的函数trans_2
将普通的Python int
作为参数,并返回另一个python int
。这就是为什么您的代码不起作用的原因。
但是,TensorFlow提供了一种包装普通python函数的简单方法,即tf.py_func
,您可以按以下方式使用它:
import tensorflow as tf
kv_dict = {1:11, 2:12, 3:13}
def trans_2(x):
return kv_dict[x]
def wrapper(x):
return tf.cast(tf.py_func(trans_2, [x], tf.int64), tf.int32)
a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
with tf.Session() as sess:
res = sess.run(b)
print(str(res))
您可以看到我添加了一个包装函数,该函数需要张量参数并返回张量,这就是为什么可以在map_fn中使用它的原因。使用强制转换是因为python默认情况下使用64位整数,而TensorFlow使用32位整数。