Tensorflow实现了一个简单的逻辑

时间:2018-03-07 14:28:50

标签: python tensorflow

我是Tensorflow的新手。我想在Tensorflow中实现一个算法,它具有下面描述的简单逻辑的一部分。

我有一个矩阵(矩阵的大小可能与批量大小不同)。我需要用零替换矩阵中的所有值,除了每行中的最大值。

我想在Tensorflow API中实现这个简单的逻辑。

例如:

Input:
[[0.50041455 0.41183667 0.37627002]
 [0.57736448 0.90280652 0.70880312]
 [0.50961863 0.94126878 0.86982843]
 [0.30285231 0.6302716  0.76009756]]
Output:
[[0.50041455 0.         0.        ]
 [0.         0.9028065  0.        ]
 [0.         0.9412688  0.        ]
 [0.         0.         0.76009756]]

这是我的代码。但我不想使用占位符“Y”并运行tenserflow 2次。 (我担心GPU中的性能)

from __future__ import print_function
import tensorflow as tf
import numpy as np

num_classes = 3
batch_size = 4

X = tf.placeholder("float", [None, num_classes]) #My matrix

values, indices = tf.nn.top_k(X, 1)
Y = tf.placeholder("float", [None, num_classes])

M = X * Y

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    x = np.random.rand(batch_size, num_classes)
    print(x)
    ind = sess.run([indices], feed_dict={X: x})
    y = np.zeros(shape=[batch_size, num_classes], dtype=float)
    for i in range(batch_size):
        y[i, ind[0][i][0]]= 1.0
    m = sess .run(M, feed_dict={X: x, Y: y})
    print(m)

有什么想法吗?

1 个答案:

答案 0 :(得分:0)

我认为这就是你想要的。

理想情况下,我们希望将行中最大元素的位置标记为1,将其他元素标记为0.然后我们只需要将这两个矩阵元素相乘。

为了获得大元素的位置,我们可以使用tf.argmax来帮助我们找到相应轴中最大元素的索引,并tf.one_hot将上面的结果转换为与初始输入矩阵。有关tf.one_hot的详细行为,您可能希望阅读 https://www.tensorflow.org/api_docs/python/tf/one_hot

示例代码如下:

import tensorflow as tf

num_classes = 3

inputs = tf.placeholder(tf.float32, (None, num_classes))

# Calculate the max element in each row, -1 means the last axis
max_in_row = tf.argmax(inputs, axis=-1)

output = tf.one_hot(max_in_row, num_classes) * inputs

with tf.Session() as sess:
    print (
        sess.run(output, feed_dict={
            inputs:
            [[0.50041455, 0.41183667, 0.37627002],
             [0.57736448, 0.90280652, 0.70880312],
             [0.50961863, 0.94126878, 0.86982843],
             [0.30285231, 0.6302716,  0.76009756]]
        })
    )

输出:

[[0.50041455 0.         0.        ]
 [0.         0.9028065  0.        ]
 [0.         0.9412688  0.        ]
 [0.         0.         0.76009756]]