使用稀疏张量的while_loop中的InvalidArgumentError

时间:2019-01-30 15:59:56

标签: python tensorflow

我正在使用while_loop来迭代更新矩阵。循环在密集张量下运行良好,但是当我使用稀疏张量时,出现以下错误:

  

InvalidArgumentError:a_indices的行数不匹配   a_values [[节点:   while / SparseTensorDenseMatMul / SparseTensorDenseMatMul =   SparseTensorDenseMatMul [T = DT_FLOAT,Tindices = DT_INT64,   adjoint_a = false,adjoint_b = false,   _device =“ / job:localhost / replica:0 / task:0 / device:GPU:0”](while / SparseTensorDenseMatMul / SparseTensorDenseMatMul / Enter,   while / SparseTensorDenseMatMul / SparseTensorDenseMatMul / Enter_1,   ConstantFolding / dense_to_sparse / Shape_enter / _1,而/ Switch_1:1)]]
  [[节点:while / Exit_1 / _5 = _Recvclient_terminated = false,   recv_device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”,   send_device =“ / job:localhost /副本:0 / task:0 / device:GPU:0”,   send_device_incarnation = 1,tensor_name =“ edge_62_while / Exit_1”,   tensor_type = DT_FLOAT,   _device =“ / job:localhost /副本:0 /任务:0 /设备:CPU:0”]]

我在两个版本之间进行的唯一更改是使用HH = tf.contrib.layers.dense_to_sparse(HH)转换HH并使用tf.sparse_tensor_dense_matmul(HH,f)而不是tf.matmul(HH,f) -显示在下面的注释代码中。

with tf.device('/gpu:0'):
    g=tf.constant(g,shape=[np.size(g),1],dtype=tf.float32)
    H=tf.constant(H,dtype=tf.float32);
    Ht=tf.transpose(H)
    HH=tf.matmul(Ht,H)
    #HH=tf.contrib.layers.dense_to_sparse(HH)
    a=tf.matmul(Ht,g)
    i=tf.constant(0,dtype=tf.int32)
    f=tf.constant(f,dtype=tf.float32)
    body = lambda i,f:(tf.add(i,1),tf.divide(tf.multiply(f,a),tf.matmul(HH,f)+10e-9))
    #body = lambda i,f:(tf.add(i,1),tf.divide(tf.multiply(f,a),tf.sparse_tensor_dense_matmul(HH,f)+10e-9))
    cond= lambda i,f:tf.less(i,iterations)
    i,f=tf.while_loop(cond,body,(i,f))
sess=tf.Session()
i,f=sess.run([i,f])

请注意,只要H,g和f足够小,此代码就可以工作。例如,对于H.shape =(8000,3840),g.shape =(8000,1),f.shape =(3840,1)和更大的值,会发生此错误,但对于H.shape =(8000, 3584),g.shape =(8000,1),f.shape =(3584,1)和更小。我需要为while循环中的稀疏张量做一些特殊的事情以确保它们保持其形状吗?

1 个答案:

答案 0 :(得分:0)

我尝试从tensorflow 1.8更新到1.12,并且tensorflow完全停止工作(ts.Session会无限期挂起)。因此,我改变了anaconda环境,并从tensorflow 1.12重新开始。在此更新/重新安装后,稀疏张量的问题消失了,尽管尚不清楚该问题是否与tensorflow的版本或我的anaconda环境中的其他问题有关。