我尝试在tensorflow上运行打击代码:
import tensorflow as tf
import numpy as np
def nuclear_norm(inputs,lamda,beta):
sigma,U,V=tf.svd(inputs)
rank=tf.count_nonzero(sigma)
svp=tf.reduce_sum(tf.cast(sigma>2*(lamda/beta), tf.int32))
def case1(sigma,lamda=1e-2):
svp1=svp
sigma1=sigma[0:svp]-2*tf.to_double(tf.constant(lamda))
return svp1,sigma1
def case2(sigma,lamda=1e-2):
svp1=tf.constant(1)
sigma1=tf.to_double(tf.constant(0))
return svp1,sigma1
svp,sigma=tf.cond(svp<tf.constant(1),lambda:case2(sigma,lamda),lambda:case1(sigma,lamda))
return rank,tf.matmul(tf.matmul(U[:,0:svp],tf.diag(sigma)),tf.transpose(V[:,0:svp]))
x=np.random.randn(2,3)
xs=tf.placeholder(tf.float64,[2,3])
lamda=1
beta=1
init = tf.global_variables_initializer()
sess = tf.Session()
a,b=nuclear_norm(xs,lamda,beta)
b=sess.run(b,feed_dict={xs:x})
当我运行代码时,IDLE显示:InvalidArgumentError:输入必须至少为等级1,得到0。详细信息为打击:
Traceback (most recent call last):
File "C:\Users\yql\Desktop\3.py", line 31, in <module>
sess.run(b,feed_dict={xs:x})
File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\client\session.py", line 905, in run
run_metadata_ptr)
File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1140, in _run
feed_dict_tensor, options, run_metadata)
File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1321, in _do_run
run_metadata)
File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1340, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input must be at least rank 1, got 0
[[Node: Diag = Diag[T=DT_DOUBLE, _device="/job:localhost/replica:0/task:0/device:CPU:0"](cond/Merge_1)]]
Caused by op 'Diag', defined at:
File "<string>", line 1, in <module>
File "C:\Program Files\Python36\lib\idlelib\run.py", line 144, in main
ret = method(*args, **kwargs)
File "C:\Program Files\Python36\lib\idlelib\run.py", line 474, in runcode
exec(code, self.locals)
File "C:\Users\yql\Desktop\3.py", line 29, in <module>
a,b=nuclear_norm(xs,lamda,beta)
File "C:\Users\yql\Desktop\3.py", line 20, in nuclear_norm
return rank,tf.matmul(tf.matmul(U[:,0:svp],tf.diag(sigma)),tf.transpose(V[:,0:svp]))
File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 1882, in diag
"Diag", diagonal=diagonal, name=name)
File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\ops.py", line 3290, in create_op
op_def=op_def)
File "C:\Program Files\Python36\lib\site-packages\tensorflow\python\framework\ops.py", line 1654, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): Input must be at least rank 1, got 0
[[Node: Diag = Diag[T=DT_DOUBLE, _device="/job:localhost/replica:0/task:0/device:CPU:0"](cond/Merge_1)]]
有时可以运行,有时不能运行。怎么了?我该如何解决错误?
答案 0 :(得分:0)
问题出在
tf.diag(sigma)
在线段中。
return rank, tf.matmul(tf.matmul(U[:, 0:svp], tf.diag(sigma)), tf.transpose(V[:, 0:svp]))
您正在传递 sigma ,该方法的 tf.dig 等级为0。
该错误易于理解,
将一组对角线元素传递给tf.diag方法时,它将返回具有给定对角线值的对角张量。
例如:
# 'diagonal' is [1, 2, 3, 4]
tf.diag(diagonal) ==> [[1, 0, 0, 0]
[0, 2, 0, 0]
[0, 0, 3, 0]
[0, 0, 0, 4]]
但是,由于仅传递单个元素,例如仅传递 1 ,因此 tf.diag 方法无法返回对角张量。
希望这会有所帮助。