在TensorFlow

时间:2015-11-13 19:04:26

标签: python tensorflow

有没有办法在TensorFlow中提取方阵的对角线?也就是说,对于像这样的矩阵:

[
 [0, 1, 2],
 [3, 4, 5],
 [6, 7, 8]
]

我想要获取元素:[0, 4, 8]

在numpy中,这非常简单,只需np.diag

在TensorFlow中,有一个diag function,但它只形成一个新的矩阵,其中包含对角线上参数中指定的元素,这不是我想要的。

我可以想象如何通过跨步来实现这一目标...但我不认为在TensorFlow中争取张量。

6 个答案:

答案 0 :(得分:9)

张量为0.8,可以用tf.diag_part()提取对角线元素(见documentation

<强>更新

for tensorflow&gt; = r1.12其tf.linalg.tensor_diag_part(见documentation

答案 1 :(得分:3)

目前可以使用tf.diag_part提取对角元素。这是他们的例子:

"""
'input' is [[1, 0, 0, 0],
            [0, 2, 0, 0],
            [0, 0, 3, 0],
            [0, 0, 0, 4]]
"""

tf.diag_part(input) ==> [1, 2, 3, 4]

旧答案(当diag_part时)不可用(如果你想获得现在不可用的东西,仍然相关):

查看math operationstensor transformations后,看起来不存在此类操作。即使您可以使用矩阵乘法提取此数据,也不会有效(对角线为O(n))。

你有三种方法,从易于努力开始。

  1. 评估张量,用numpy提取对角线,用TF
  2. 构建变量
  3. 以Anurag建议的方式使用tf.pack(也使用tf.shape提取值3
  4. 编写自己的op in C++,重建TF并原生使用。

答案 2 :(得分:3)

使用tf.diag_part()

with tf.Session() as sess:
    x = tf.ones(shape=[3, 3])
    x_diag = tf.diag_part(x)
    print(sess.run(x_diag ))

答案 3 :(得分:2)

这可能是一种解决方法,但有效。

|| handle_error

答案 4 :(得分:0)

使用gather操作。

x = tensorflow.Variable([[1,2,3],[4,5,6],[7,8,9]])
x_flat = tf.reshape(x, [-1])  # flatten the matrix
x_diag = tf.gather(x, [0, 3, 6])

答案 5 :(得分:0)

根据上下文,掩码可以是“取消”的好方法。关闭矩阵的对角线元素,特别是如果你打算减少矩阵:

Class Dialog_Box
    {
    JTextField Name; //1st TextField
    JTextField Place; //2nd TextField
    JTextField City; //3rd TextField
    JTextField Country; //4th
    .
    .
    .
    JTextField Pincode; //20th TextField

    Dialog_Box()
        {
        JButton Add = new JButton(
            {
            public void Action Performed()
                {
                String Name_data=Name.getText();//Getting first input
                String Place_data=Place.getText();
                .
                .
                .
                String Pincode=Pincode.getText();//Getting 20th input
                //Storing these data in Database
                }
            }); 
        }
    }