I'm trying to implement the following NumPy code in TensorFlow:
import numpy as np

x = np.array([1,1,1,1,1,2,2,2,3,3,3])
x_u = np.unique(x)
r_indices = []
for v in x_u:
    # indices of x where the value is v; pick floor(50%) of them without replacement
    v_indices = np.argwhere(x==v).flatten()
    selected_indices = np.random.choice(v_indices, size=int(0.5 * v_indices.shape[0]), replace=False)
    r_indices.append(selected_indices)
r_indices = np.concatenate(r_indices)
print(r_indices)
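The code outputs an array of 4 elements. Basically, for each unique value (here only 1, 2 or 3) it randomly selects 50% of that value's indices, rounded down by int(), and returns them: 2 of the five 1s, 1 of the three 2s and 1 of the three 3s.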
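The same per-value sampling can also be written without a Python loop, using only whole-array operations (a random permutation, a stable sort and cumulative sums). This avoids producing a variable-length result per value, which is what makes a per-element mapping awkward. Below is a minimal NumPy sketch, assuming floor rounding as in int(0.5 * n); sample_half_per_group is a hypothetical name, and a TensorFlow port is not shown here.

```python
import numpy as np

def sample_half_per_group(x, rng):
    # Hypothetical helper: for each distinct value in x, return a random
    # floor(50%) of the indices where that value occurs, without a Python loop.
    x = np.asarray(x)
    perm = rng.permutation(x.shape[0])                 # random order of all positions
    order = perm[np.argsort(x[perm], kind="stable")]   # grouped by value, random within each group
    _, counts = np.unique(x, return_counts=True)       # group sizes, in sorted value order
    starts = np.repeat(np.cumsum(counts) - counts, counts)  # group start offset, per element
    rank = np.arange(x.shape[0]) - starts              # 0, 1, 2, ... inside each group
    keep = rank < np.repeat(counts // 2, counts)       # floor(50%) per group, like int(0.5 * n)
    return order[keep]

x = np.array([1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3])
r_indices = sample_half_per_group(x, np.random.default_rng(0))
print(r_indices)  # 4 indices: two from 0-4, one from 5-7, one from 8-10
```

The stable sort is the key design choice: it groups positions by value while keeping the random order produced by the permutation, so "the first floor(n/2) of each group" is a uniform random subset of that group.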
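I tried tf.map_fn, but it doesn't work because that function requires its outputs to have a consistent shape, and here the number of selected indices differs per value. I also tried tf.while_loop, without success.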
indicators = tf.constant([1,1,1,1,1,2,2,2,3,3,3])
u_indicators = tf.unique(indicators)
ta = tf.Variable([], dtype=tf.int32)
num_elements = tf.shape(u_indicators)[0]
def body_func(i, ta):
    v = tf.gather(u_indicators, i)
    indices = tf.where(tf.equal(v, indicators))
    indices = tf.squeeze(indices) #1D tensor
    idxs = tf.range(tf.shape(indices)[0])
    num_selected = tf.cast(tf.cast(tf.shape(indices)[0], tf.float32) * 0.5, tf.int32)
    ridxs = tf.random_shuffle(idxs)[:num_selected]
    ta = tf.concat([ta, ridxs], axis=0)
    return (i+1, ta)
i = tf.constant(0)
init_state = (i, ta)
condition = lambda i, _: i < num_elements
n, ta_final = tf.while_loop(condition, body_func, init_state, shape_invariants=[i.get_shape(), tf.TensorShape([None])])
# get the final result
ta_final_result = ta_final.stack()
# run the graph
with tf.Session() as sess:
    # print the output of ta_final_result
    print(sess.run(ta_final_result))
The code above throws this error:
Traceback (most recent call last):
  File "test.py", line 23, in <module>
    n, ta_final = tf.while_loop(condition, body_func, init_state, shape_invariants=[i.get_shape(), tf.TensorShape([None])])
  File "/Users/tiendh/deeplearning/lib/python3.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3484, in while_loop
    loop_vars, shape_invariants, expand_composites=False)
  File "/Users/tiendh/deeplearning/lib/python3.7/site-packages/tensorflow/python/util/nest.py", line 304, in assert_same_structure
    % (str(e), str1, str2))
TypeError: The two structures don't have the same nested structure.
First structure: type=tuple str=(<tf.Tensor 'Const_1:0' shape=() dtype=int32>, <tf.Variable 'Variable:0' shape=(0,) dtype=int32_ref>)
Second structure: type=list str=[TensorShape([]), TensorShape([Dimension(None)])]
Any suggestions?
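Two things stand out when this attempt is read against the NumPy version. First, the traceback reports that init_state is a tuple while shape_invariants is a list, so the two containers do not have the same structure; giving both the same type (for example, both tuples) is the obvious first thing to try. Second, body_func shuffles tf.range(...), i.e. positions 0..n-1 inside the group, rather than the group's indices into indicators, so the picked positions still have to be mapped back (the NumPy code avoids this by sampling from v_indices directly). The sketch below is plain NumPy, not a TensorFlow solution: it mirrors the intended loop with a tuple state (i, acc) that the body returns in the same structure. The name sample_half_loop is hypothetical.

```python
import numpy as np

def sample_half_loop(x, rng):
    # Hypothetical helper: a NumPy mirror of the intended while-loop.
    # The loop state is a tuple (i, acc), and body() returns a tuple with
    # the same structure.
    x = np.asarray(x)
    uniq = np.unique(x)

    def body(i, acc):
        v_indices = np.flatnonzero(x == uniq[i])                     # indices into x
        n_keep = v_indices.shape[0] // 2                             # floor(50%)
        positions = rng.permutation(v_indices.shape[0])[:n_keep]     # positions inside the group
        return (i + 1, np.concatenate([acc, v_indices[positions]]))  # map back to indices into x

    state = (0, np.array([], dtype=np.intp))
    while state[0] < uniq.shape[0]:   # same role as the while-loop condition
        state = body(*state)
    return state[1]

x = np.array([1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3])
print(sample_half_loop(x, np.random.default_rng(0)))  # 4 indices, as in the NumPy version
```

The mapping step v_indices[positions] is the part the TensorFlow body skips; in graph code it would presumably be a gather of the shuffled positions from indices.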
Answer 0 (score: 0)
In the end I fixed the code. Here is the solution:
init