Reducing a 1-D tensor based on similar values

Time: 2019-09-23 22:07:49

Tags: tensorflow

I'm trying to implement the following numpy code in TensorFlow.

import numpy as np

x = np.array([1,1,1,1,1,2,2,2,3,3,3])
x_u = np.unique(x)
r_indices = []
for v in x_u:
    v_indices = np.argwhere(x==v).flatten()
    selected_indices = np.random.choice(v_indices, size=int(0.5 * v_indices.shape[0]), replace=False)
    r_indices.append(selected_indices)
r_indices = np.concatenate(r_indices)
print(r_indices)

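The code outputs an array of 4 elements. Basically, for each unique value (here the only values are 1, 2 and 3) it randomly selects 50% (rounded down) of the positions where that value occurs, and returns those indices. I tried tf.map_fn, but it doesn't work because tf.map_fn requires every output to have the same shape, while here each value produces a different number of indices. I also tried tf.while_loop, without success.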
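The length of the output follows directly from the per-value counts: each value v contributes int(0.5 * count(v)) indices. A small sketch that computes those counts with np.unique(..., return_counts=True), using the same x as above:

```python
import numpy as np

x = np.array([1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3])
# unique values and how often each one occurs
values, counts = np.unique(x, return_counts=True)
# indices the loop draws per value; counts // 2 equals int(0.5 * n) for n >= 0
per_value = counts // 2
print(dict(zip(values.tolist(), per_value.tolist())))  # {1: 2, 2: 1, 3: 1}
print(int(per_value.sum()))  # 4
```

So value 1 (5 occurrences) yields 2 indices and values 2 and 3 (3 occurrences each) yield 1 each, which gives the 4-element result.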

import tensorflow as tf

indicators = tf.constant([1,1,1,1,1,2,2,2,3,3,3])
u_indicators = tf.unique(indicators)
ta = tf.Variable([], dtype=tf.int32)
num_elements = tf.shape(u_indicators)[0]

def body_func(i, ta):
    v = tf.gather(u_indicators, i)
    indices = tf.where(tf.equal(v, indicators))
    indices = tf.squeeze(indices) #1D tensor
    idxs = tf.range(tf.shape(indices)[0])
    num_selected = tf.cast(tf.cast(tf.shape(indices)[0], tf.float32) * 0.5, tf.int32)
    ridxs = tf.random_shuffle(idxs)[:num_selected]
    ta = tf.concat([ta, ridxs], axis=0)
    return (i+1, ta)

i = tf.constant(0)
init_state = (i, ta)
condition = lambda i, _: i < num_elements
n, ta_final = tf.while_loop(condition, body_func, init_state, shape_invariants=[i.get_shape(), tf.TensorShape([None])])
# get the final result
ta_final_result = ta_final.stack()

# run the graph
with tf.Session() as sess:
    # print the output of ta_final_result
    print(sess.run(ta_final_result))

The code above throws this error:

Traceback (most recent call last):
  File "test.py", line 23, in <module>
    n, ta_final = tf.while_loop(condition, body_func, init_state, shape_invariants=[i.get_shape(), tf.TensorShape([None])])
  File "/Users/tiendh/deeplearning/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3484, in while_loop
    loop_vars, shape_invariants, expand_composites=False)
  File "/Users/tiendh/deeplearning/lib/python3.7/site-packages/tensorflow/python/util/nest.py", line 304, in assert_same_structure
    % (str(e), str1, str2))
TypeError: The two structures don't have the same nested structure.

First structure: type=tuple str=(<tf.Tensor 'Const_1:0' shape=() dtype=int32>, <tf.Variable 'Variable:0' shape=(0,) dtype=int32_ref>)

Second structure: type=list str=[TensorShape([]), TensorShape([Dimension(None)])]
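The message comes from a structure check between loop_vars and shape_invariants: init_state is a tuple, but shape_invariants is passed as a list, and the check treats those as different sequence types. As a hedged illustration, the public tf.nest.assert_same_structure performs this kind of check (it is assumed here to behave like the internal nest call in the traceback), and the helper same_structure is hypothetical:

```python
import tensorflow as tf

loop_vars = (tf.constant(0), tf.zeros([0], dtype=tf.int32))


def same_structure(a, b):
    # hypothetical helper: True if tf.nest sees the same nesting, including sequence type
    try:
        tf.nest.assert_same_structure(a, b)
        return True
    except (TypeError, ValueError):
        return False


as_list = [tf.TensorShape([]), tf.TensorShape([None])]
as_tuple = (tf.TensorShape([]), tf.TensorShape([None]))
print(same_structure(loop_vars, as_list))   # False: tuple vs list
print(same_structure(loop_vars, as_tuple))  # True
```

Passing shape_invariants as a tuple, (i.get_shape(), tf.TensorShape([None])), removes this particular error.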

Any suggestions?

1 answer:

Answer 0 (score: 0)

In the end, I fixed the code. Here is the solution:

init
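The answer's code is cut off above. For reference, here is one possible corrected version of the question's loop. It is a sketch, not necessarily what the answerer wrote. Besides passing shape_invariants as a tuple, it fixes problems that would surface once the structure error is gone:

- tf.unique returns a (y, idx) pair, so only y is kept.
- tf.where returns int64, so the indices are cast to int32 before concatenating.
- The accumulator is a plain empty tensor instead of a tf.Variable.
- The shuffled values are the original positions, as in the numpy code, rather than 0..n-1.
- tf.reshape replaces tf.squeeze so that a single match still gives a 1-D tensor.
- The result is used directly, because a regular tensor has no .stack() method (that method belongs to tf.TensorArray).

It keeps the question's graph/session style through tf.compat.v1.Session. It is assumed to run on TF 1.x releases that provide tf.compat.v1 and on TF 2.x. The function name sample_half_per_value is hypothetical.

```python
import numpy as np
import tensorflow as tf


def sample_half_per_value(values):
    """Hypothetical helper: randomly keep 50% (rounded down) of the positions of each unique value."""
    graph = tf.Graph()
    with graph.as_default():
        indicators = tf.constant(np.asarray(values, dtype=np.int32))
        # tf.unique returns (y, idx); only the unique values y are needed
        u_indicators, _ = tf.unique(indicators)
        num_elements = tf.shape(u_indicators)[0]

        def body_func(i, acc):
            v = tf.gather(u_indicators, i)
            # positions where indicators == v; reshape keeps a 1-D shape even for a single match
            indices = tf.reshape(tf.where(tf.equal(indicators, v)), [-1])
            # tf.where yields int64 while the accumulator is int32
            indices = tf.cast(indices, tf.int32)
            # same as int(0.5 * n) in the numpy version
            num_selected = tf.shape(indices)[0] // 2
            # shuffle the original positions themselves, then keep the first num_selected
            selected = tf.random.shuffle(indices)[:num_selected]
            return i + 1, tf.concat([acc, selected], axis=0)

        i0 = tf.constant(0)
        # an ordinary empty tensor instead of a tf.Variable
        acc0 = tf.zeros([0], dtype=tf.int32)
        _, result = tf.while_loop(
            lambda i, _: i < num_elements,
            body_func,
            (i0, acc0),
            # a tuple, so its structure matches the loop_vars tuple
            shape_invariants=(i0.get_shape(), tf.TensorShape([None])),
        )
        with tf.compat.v1.Session(graph=graph) as sess:
            # result is already a 1-D tensor, so no .stack() call is needed
            return sess.run(result)


print(sample_half_per_value(np.array([1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3])))
```

The output is random, but for this input it always has 4 elements. The first two come from positions 0 to 4, the third from 5 to 7, and the fourth from 8 to 10, grouped in the order in which tf.unique first encounters each value.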