假设我有一个ts
形状[s1, s2, s3]
的张量,我想用tf.map_fn
迭代它:
tf.map_fn(lambda dim1:
tf.map_fn(lambda dim2:
do_sth(dim, idx1, idx2)
,dim)
,ts)
上面的idx1
和idx2
是ts
当前所在的{0}的维度0和维度1的索引。我怎样才能获得这些?我希望得到这样的东西,就像我做的那样:
do_sth()
我不能这样做的原因是大部分时间for idx1 in range(s1):
for idx2 in range(s2):
tensor = ts[idx1][idx2]
do_sth(tensor, idx1, idx2)
未知(即s1, s2, s3
形状ts
或类似)
这可能吗?
答案 0 :(得分:1)
我建议你在最后添加一个额外的维度,并在运行命令之前用索引填充它。
此代码(已测试)将索引添加到X和Y索引的值乘以10和100:
from __future__ import print_function
import tensorflow as tf
import numpy as np
r = 3
a = tf.reshape( tf.constant( range( r * r ) ), ( r, r ) )
x = tf.tile( tf.cast( tf.lin_space( 0.0, r - 1, r )[ None, : ], tf.int32 ), [ r, 1 ] )
y = tf.tile( tf.cast( tf.lin_space( 0.0, r - 1, r )[ :, None ], tf.int32 ), [ 1, r ] )
b = tf.stack( [ a, x, y ], axis = -1 )
c = tf.map_fn( lambda y: tf.map_fn( lambda x:
x[ 0 ] + 10 * x[ 1 ] + 100 * x[ 2 ]
, y ), b )
with tf.Session() as sess:
res = sess.run( [ c ] )
for x in res:
print()
print( x )
输出:
[[0 11 22]
[103 114 125]
[206 217 228]]