如何在Tensorflow中删除

时间:2019-12-23 15:14:30

标签: python tensorflow matrix delete-row

我想从Matrx中删除一行。但是,tensorflow没有模块“删除”,因此,如果有人知道如何做。而且,那么我想在矩阵中添加一行,所以如果有人也知道的话,请。谢谢!

 MatrizDesnormalizada = tf.delete(MatrizDesnormalizada,indice,axis = 0) 

2 个答案:

答案 0 :(得分:0)

您可以使用tf.gather进行所需的操作。以下是一个工作示例。假设您要删除具有0th行数据的张量的3行。

import tensorflow as tf

a = tf.constant([[[1,2,3,4],[4,3,2,1]],[[2,3,4,5],[2,3,4,5]],[[3,4,5,6],[3,4,5,6]]])

del_a = tf.gather(a, [1,2])

with tf.Session() as sess:
  print(sess.run(del_a))

答案 1 :(得分:0)

这是一个可用作矩阵的np.delete的函数

def tf_delete(tensor,index,row=True):
    
    if row:
        sub = list(range(tensor.shape[0]))
    else:
        sub = list(range(tensor.shape[1]))
    sub.pop(index)
    
    if row:
        return  tf.gather(tensor,sub)
    return tf.transpose(tf.gather(tf.transpose(tensor),sub))

,其中行指定您要删除行还是列 例如:

print(tensor)
tensor = tf_delete(tensor,1,row=True)
print(tensor)
tensor = tf_delete(tensor,1,row=False)
print(tensor)

返回

tf.Tensor(
[[9. 5. 1. 3.]
 [9. 9. 6. 8.]
 [9. 9. 9. 2.]
 [9. 9. 9. 9.]], shape=(4, 4), dtype=float32)
tf.Tensor(
[[9. 5. 1. 3.]
 [9. 9. 9. 2.]
 [9. 9. 9. 9.]], shape=(3, 4), dtype=float32)
tf.Tensor(
[[9. 1. 3.]
 [9. 9. 2.]
 [9. 9. 9.]], shape=(3, 3), dtype=float32)