我有一个矩阵:
col_indices =
[[0 1]
[1 2]
[2 3]]
对于每一行,我想使用列索引选择一些元素:
row_indices =
[[0 0]
[1 1]
[2 2]]
在Numpy中,我可以创建行索引:
params[row_indices, col_indices]
并执行tf_params = tf.constant(params)
tf_col_indices = tf.constant(col_indices, dtype=tf.int32)
tf_row_indices = tf.constant(row_indices, dtype=tf.int32)
tf_params[row_indices, col_indices]
在TenforFlow中,我这样做了:
ValueError: Shape must be rank 1 but is rank 3
但是出现了一个错误:
{{1}}
这是什么意思?我该如何正确地进行这种索引?
谢谢!
答案 0 :(得分:0)