我需要基于这样的输入张量创建一个1和0的张量
input = tf.constant([3, 2, 4, 1, 0])
输出=
0 0 0 0 0
0 0 0 1 0
0 0 1 1 0
0 0 1 1 1
0 1 1 1 1
本质上,输入张量的每个值的索引(i)+1指定了我将在该列中开始放置1s的行。
答案 0 :(得分:1)
这里是TensorFlow操作的实现。有关详细信息,请参见评论。
import tensorflow as tf
input = tf.placeholder(tf.int32, [None])
# Find indices that sort the input
# There is no argsort yet in the stable API,
# but you can do the same with top_k
_, order = tf.nn.top_k(-input, tf.shape(input)[0])
# Or use the implementation in contrib
order = tf.contrib.framework.argsort(input)
# Build triangular lower matrix
idx = tf.range(tf.shape(input)[0])
triangular = idx[:, tf.newaxis] > idx
# Reorder the columns according to the order
result = tf.gather(triangular, order, axis=1)
# Cast result from bool to int or float as needed
result = tf.cast(result, tf.int32)
with tf.Session() as sess:
print(sess.run(result, feed_dict={input: [3, 2, 4, 1, 0]}))
输出:
[[0 0 0 0 0]
[0 0 0 1 0]
[0 0 1 1 0]
[0 0 1 1 1]
[0 1 1 1 1]]
答案 1 :(得分:0)
此代码可以达到预期的效果。但是它不使用矢量化函数来简化此过程。代码中有一些注释。
形状是根据问题假定的。如果更改输入,则需要进行更多测试。
init = tf.constant_initializer(np.zeros((5, 5)))
inputinit = tf.constant([3, 2, 4, 1, 0])
value = tf.gather( inputinit , [0,1,2,3,4])
sess = tf.Session()
#Combine rows to get the final desired tensor
def merge(a) :
for i in range(0, ( value.get_shape()[0] - 1 )) :
compare = tf.to_int32(
tf.not_equal(tf.gather(a, i ),
tf.gather(a, ( i + 1 ))))
a = tf.scatter_update(a, ( i + 1 ), compare)
#Insert zeros in first row and move all other rows down by one position.
#This eliminates the last row which isn't needed
return tf.concat([tf.reshape([0,0,0,0,0],(1,5)),
a[0:1],a[1:2],a[2:3],a[3:4]],axis=0)
# Insert ones by stitching individual tensors together by inserting one in
# the desired position.
def insertones() :
a = tf.get_variable("a", [5, 5], dtype=tf.int32, initializer=init)
sess.run(tf.global_variables_initializer())
for i in range(0, ( value.get_shape()[0] )) :
oldrow = tf.gather(a, i )
index = tf.squeeze( value[i:( i + 1 )] )
begin = oldrow[: index ]
end = oldrow[index : 4]
newrow = tf.concat([begin, tf.constant([1]), end], axis=0)
if( i <= 4 ) :
a = tf.scatter_update(a, i, newrow)
return merge(a)
a = insertones()
print(sess.run(a))
输出是这个。
[[0 0 0 0 0]
[0 0 0 1 0]
[0 0 1 1 0]
[0 0 1 1 1]
[0 1 1 1 1]]