TensorFlow - InvalidArgumentError (see above for traceback): Input must be at least rank 1, got 0

Time: 2018-09-26 03:08:09

Tags: tensorflow

I am trying to run the following code in TensorFlow:

import tensorflow as tf
import numpy as np

def nuclear_norm(inputs,lamda,beta):           
    sigma,U,V=tf.svd(inputs)
    rank=tf.count_nonzero(sigma)
    svp=tf.reduce_sum(tf.cast(sigma>2*(lamda/beta), tf.int32))  
    def case1(sigma,lamda=1e-2):
        svp1=svp
        sigma1=sigma[0:svp]-2*tf.to_double(tf.constant(lamda))
        return svp1,sigma1

    def case2(sigma,lamda=1e-2):
        svp1=tf.constant(1)
        sigma1=tf.to_double(tf.constant(0))
        return svp1,sigma1

    svp,sigma=tf.cond(svp<tf.constant(1),lambda:case2(sigma,lamda),lambda:case1(sigma,lamda))

    return rank,tf.matmul(tf.matmul(U[:,0:svp],tf.diag(sigma)),tf.transpose(V[:,0:svp]))

x=np.random.randn(2,3)
xs=tf.placeholder(tf.float64,[2,3])
lamda=1
beta=1

init = tf.global_variables_initializer()
sess = tf.Session()
a,b=nuclear_norm(xs,lamda,beta)
b=sess.run(b,feed_dict={xs:x})

When I run the code, IDLE shows InvalidArgumentError: Input must be at least rank 1, got 0. The full details are below:

    Traceback (most recent call last):
      File "C:\Users\yql\Desktop\3.py", line 31, in <module>
        sess.run(b,feed_dict={xs:x})
      File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\client\session.py", line 905, in run
        run_metadata_ptr)
      File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1140, in _run
        feed_dict_tensor, options, run_metadata)
      File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1321, in _do_run
        run_metadata)
      File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1340, in _do_call
        raise type(e)(node_def, op, message)
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Input must be at least rank 1, got 0
         [[Node: Diag = Diag[T=DT_DOUBLE, _device="/job:localhost/replica:0/task:0/device:CPU:0"](cond/Merge_1)]]

    Caused by op 'Diag', defined at:
      File "<string>", line 1, in <module>
      File "C:\Program Files\Python36\lib\idlelib\run.py", line 144, in main
        ret = method(*args, **kwargs)
      File "C:\Program Files\Python36\lib\idlelib\run.py", line 474, in runcode
        exec(code, self.locals)
      File "C:\Users\yql\Desktop\3.py", line 29, in <module>
        a,b=nuclear_norm(xs,lamda,beta)
      File "C:\Users\yql\Desktop\3.py", line 20, in nuclear_norm
        return rank,tf.matmul(tf.matmul(U[:,0:svp],tf.diag(sigma)),tf.transpose(V[:,0:svp]))
      File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 1882, in diag
        "Diag", diagonal=diagonal, name=name)
      File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
        op_def=op_def)
      File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\ops.py", line 3290, in create_op
        op_def=op_def)
      File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\ops.py", line 1654, in __init__
        self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

    InvalidArgumentError (see above for traceback): Input must be at least rank 1, got 0
         [[Node: Diag = Diag[T=DT_DOUBLE, _device="/job:localhost/replica:0/task:0/device:CPU:0"](cond/Merge_1)]]

Sometimes it runs and sometimes it doesn't. What is wrong, and how can I fix the error?

1 Answer

Answer 0 (score: 0)

The problem is the tf.diag(sigma) call in this line:

return rank, tf.matmul(tf.matmul(U[:, 0:svp], tf.diag(sigma)), tf.transpose(V[:, 0:svp]))

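The sigma you pass to tf.diag has rank 0, i.e. it is a scalar. That happens whenever no singular value of the input exceeds 2*(lamda/beta): svp is then 0, tf.cond takes the case2 branch, and case2 returns sigma1 = tf.to_double(tf.constant(0)), which is a scalar rather than a vector. Because x comes from np.random.randn, its singular values change on every run, which is why the code sometimes works and sometimes fails.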
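To see why the failure is intermittent, the branch condition can be reproduced in NumPy. This is a minimal sketch, assuming the same threshold 2*(lamda/beta) as the question's code; count_kept is a hypothetical helper name, and np.linalg.svd with compute_uv=False stands in for tf.svd (it returns only the singular values).

```python
import numpy as np

def count_kept(x, lamda=1.0, beta=1.0):
    """Hypothetical helper: number of singular values above 2*(lamda/beta).

    Mirrors svp in the question's nuclear_norm; a result of 0 means
    tf.cond would take the case2 branch, which returns a scalar.
    """
    sigma = np.linalg.svd(x, compute_uv=False)  # singular values only
    return int(np.sum(sigma > 2 * (lamda / beta)))

# Singular values 3 and 1: one of them exceeds 2, so case1 would run.
big = np.array([[3.0, 0.0, 0.0],
                [0.0, 1.0, 0.0]])
# Singular values 0.5 and 0.5: none exceeds 2, so case2 would run.
small = np.array([[0.5, 0.0, 0.0],
                  [0.0, 0.5, 0.0]])

print(count_kept(big))    # 1
print(count_kept(small))  # 0
```

A random 2x3 matrix from np.random.randn can land on either side of the threshold, so the same script alternates between the two branches.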

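The error itself is easy to understand. When you pass a rank-1 tensor of diagonal values to tf.diag, it returns a diagonal tensor with those values on its diagonal.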

For example:

# 'diagonal' is [1, 2, 3, 4]
tf.diag(diagonal) ==> [[1, 0, 0, 0],
                       [0, 2, 0, 0],
                       [0, 0, 3, 0],
                       [0, 0, 0, 4]]
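NumPy's np.diag enforces a similar rank requirement, which makes the difference between a scalar and a one-element vector easy to check without building a graph. This is an analogy rather than the TensorFlow op itself: NumPy raises ValueError with its own message instead of InvalidArgumentError.

```python
import numpy as np

# A rank-1 input becomes a square matrix with those values on the diagonal.
d = np.diag(np.array([1, 2, 3, 4]))
print(d.shape)  # (4, 4)

# A one-element vector is still rank 1, so it gives a 1x1 matrix.
print(np.diag(np.array([0.0])))  # [[0.]]

# A rank-0 scalar is rejected, just as tf.diag rejects it.
scalar_rejected = False
try:
    np.diag(np.float64(0.0))
except ValueError as err:
    scalar_rejected = True
    print("ValueError:", err)
```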

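In your case, however, the input is a single scalar (a rank-0 value such as 0, not a one-element vector such as [0]), so tf.diag cannot build a diagonal tensor from it. To fix the error, make case2 return a one-element vector instead, for example sigma1 = tf.zeros([1], dtype=tf.float64). With svp1 = 1, U[:, 0:1] is 2x1, tf.diag(sigma1) is 1x1 and tf.transpose(V[:, 0:1]) is 1x3, so the result is a 2x3 zero matrix.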
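The shapes of the fixed version can be checked end to end with a NumPy sketch. Several assumptions apply: shrink_singular_values is a hypothetical name, np.linalg.svd replaces tf.svd (note that NumPy returns V already transposed, while tf.svd returns V), and the shrinkage by 2*lamda follows the question's case1. The key point is that the fallback branch returns np.zeros(1), a one-element vector, instead of a scalar.

```python
import numpy as np

def shrink_singular_values(x, lamda=1.0, beta=1.0):
    """Hypothetical NumPy version of the question's nuclear_norm.

    Returns (rank, reconstruction), keeping only singular values above
    2*(lamda/beta) and shrinking them by 2*lamda, as case1 does.
    """
    u, sigma, vt = np.linalg.svd(x, full_matrices=False)
    rank = int(np.count_nonzero(sigma))
    svp = int(np.sum(sigma > 2 * (lamda / beta)))
    if svp < 1:
        # case2: keep one column, but as a rank-1 vector of zeros,
        # so np.diag receives shape (1,) rather than a scalar.
        svp, kept = 1, np.zeros(1)
    else:
        # case1: shrink the singular values that passed the threshold.
        kept = sigma[:svp] - 2 * lamda
    # (2, svp) @ (svp, svp) @ (svp, 3) -> (2, 3)
    return rank, u[:, :svp] @ np.diag(kept) @ vt[:svp, :]

small = np.array([[0.5, 0.0, 0.0],
                  [0.0, 0.5, 0.0]])
big = np.array([[3.0, 0.0, 0.0],
                [0.0, 1.0, 0.0]])

rank, out = shrink_singular_values(small)  # fallback (case2) branch
print(rank, out.shape)       # 2 (2, 3)
print(np.allclose(out, 0))   # True

# case1 branch: keeps one singular value, shrunk from 3 to 1.
_, out_big = shrink_singular_values(big)
```

Returning a vector in both branches keeps the rank of sigma the same whichever branch runs, which is what tf.diag needs.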

Hope this helps.