有没有办法在TensorFlow中提取方阵的对角线?也就是说,对于像这样的矩阵:
[
[0, 1, 2],
[3, 4, 5],
[6, 7, 8]
]
我想要获取元素:[0, 4, 8]
在numpy中,这非常简单,只需np.diag:
在TensorFlow中,有一个diag function,但它只形成一个新的矩阵,其中包含对角线上参数中指定的元素,这不是我想要的。
我可以想象如何通过跨步来实现这一目标...但我不认为在TensorFlow中争取张量。
答案 0 :(得分:9)
张量为0.8,可以用tf.diag_part()
提取对角线元素(见documentation)
<强>更新强>
for tensorflow&gt; = r1.12其tf.linalg.tensor_diag_part
(见documentation)
答案 1 :(得分:3)
目前可以使用tf.diag_part提取对角元素。这是他们的例子:
"""
'input' is [[1, 0, 0, 0],
[0, 2, 0, 0],
[0, 0, 3, 0],
[0, 0, 0, 4]]
"""
tf.diag_part(input) ==> [1, 2, 3, 4]
旧答案(当diag_part时)不可用(如果你想获得现在不可用的东西,仍然相关):
查看math operations和tensor transformations后,看起来不存在此类操作。即使您可以使用矩阵乘法提取此数据,也不会有效(对角线为O(n)
)。
你有三种方法,从易于努力开始。
答案 2 :(得分:3)
使用tf.diag_part()
with tf.Session() as sess:
x = tf.ones(shape=[3, 3])
x_diag = tf.diag_part(x)
print(sess.run(x_diag ))
答案 3 :(得分:2)
这可能是一种解决方法,但有效。
|| handle_error
答案 4 :(得分:0)
使用gather
操作。
x = tensorflow.Variable([[1,2,3],[4,5,6],[7,8,9]])
x_flat = tf.reshape(x, [-1]) # flatten the matrix
x_diag = tf.gather(x, [0, 3, 6])
答案 5 :(得分:0)
根据上下文,掩码可以是“取消”的好方法。关闭矩阵的对角线元素,特别是如果你打算减少矩阵:
Class Dialog_Box
{
JTextField Name; //1st TextField
JTextField Place; //2nd TextField
JTextField City; //3rd TextField
JTextField Country; //4th
.
.
.
JTextField Pincode; //20th TextField
Dialog_Box()
{
JButton Add = new JButton(
{
public void Action Performed()
{
String Name_data=Name.getText();//Getting first input
String Place_data=Place.getText();
.
.
.
String Pincode=Pincode.getText();//Getting 20th input
//Storing these data in Database
}
});
}
}