从张量流中的张量中提取特定元素

时间:2017-01-23 10:15:14

标签: python tensorflow

我在python上使用tensorflow 我有一个形状的数据张量[?,5,37]和一个形状的idx张量[?,5]

我想从数据中提取元素并获得形状[?,5]的输出,以便:

output[i][j] = data[i][j][idx[i, j]] for all i in range(?) and j in range(5)

看起来lof tf.gather_nd()函数最接近我的需要,但我不知道如何在我的情况下使用它...

谢谢!

编辑:我设法用gather_nd完成,如下所示,但是有更好的选择吗? (看起来有点笨拙)

    nRows = tf.shape(length_label)[0] ==> ?
    nCols = tf.constant(MAX_LENGTH_INPUT + 1, dtype=tf.int32) ==> 5
    m1 = tf.reshape(tf.tile(tf.range(nCols), [nRows]),
                                           shape=[nRows, nCols])
    m2 = tf.transpose(tf.reshape(tf.tile(tf.range(nRows), [nCols]),
                                            shape=[nCols, nRows]))
    indices = tf.pack([m2, m1, idx], axis=-1)
    # indices should be of shape [?, 5, 3] with indices[i,j]==[i,j,idx[i,j]]
    output = tf.gather_nd(data, indices=indices)

1 个答案:

答案 0 :(得分:1)

我设法使用void Start() { NetworkServer.Listen(9000); NetworkServer.RegisterHandler(MsgType.Connect, OnConnected); NetworkServer.RegisterHandler(MsgType.Disconnect, OnDisconnected); NetworkServer.RegisterHandler(MsgType.Error, OnError); } public void OnConnected(NetworkMessage netMsg) { Debug.Log("Client Connected"); } public void OnDisconnected(NetworkMessage netMsg) { Debug.Log("Disconnected"); } public void OnError(NetworkMessage netMsg) { Debug.Log("Error while connecting"); } 执行此操作,如下所示

gather_nd