My code:
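def f(x):
    import tensorflow as tf
    from keras import backend as K  # assumed: K is the Keras backend used elsewhere
    print('x.shape={0}'.format(x.shape))
    # map coordinates in channels 0:2 to integer grid cells
    idx = K.cast(x[:, :, 0:2]*15.5+15.5, "int64")
    print('idx.shape={0}'.format(idx.shape))
    st_z = tf.SparseTensor(idx, values=0.0, dense_shape=[K.shape(x)[0], 32, 32])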
Output:
x.shape=(?, 21, 3)
idx.shape=(?, 21, 2)
Error:
ValueError: Shape (?, 21, 2) must have rank 2
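For reference, here is a minimal eager sketch (assuming TensorFlow 2.x) of the shapes tf.SparseTensor accepts. `indices` must be a rank-2 int64 tensor of shape [N, rank], where rank is the length of `dense_shape`, and `values` must be a 1-D tensor of length N. The `idx` above is rank 3, (?, 21, 2), which is what the error is complaining about. A scalar `values=0.0` would likely fail the same kind of check once `indices` is fixed, because `values` is expected to be rank 1.

```python
import tensorflow as tf

# indices: 2-D int64, shape [N, rank], where rank == len(dense_shape) == 3 here
indices = tf.constant([[0, 1, 2], [1, 3, 4]], dtype=tf.int64)  # N = 2
# values: 1-D, one entry per row of `indices`
values = tf.constant([1.0, 2.0])
st = tf.SparseTensor(indices=indices, values=values, dense_shape=[2, 32, 32])

# to_dense expects indices sorted in row-major order with no duplicates
dense = tf.sparse.to_dense(st)
print(dense.shape)            # (2, 32, 32)
print(float(dense[1, 3, 4]))  # 2.0
```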
How can I fix this? Thanks.
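One possible fix, sketched here under assumptions rather than as a confirmed answer: a rank-3 sparse tensor of shape [batch, 32, 32] needs one index row (b, i, j) per point. So prepend the batch index to each (i, j) pair, then flatten (batch, 21, 3) to (batch * 21, 3). The sketch assumes TensorFlow 2.x, that channels 0:2 of `x` lie in [-1, 1] (so `*15.5 + 15.5` lands in [0, 31]), and, as a hypothetical choice, that channel 2 holds the value to place at each point. The function name `to_sparse_grid` is made up for illustration.

```python
import tensorflow as tf

def to_sparse_grid(x, grid=32):
    """Hypothetical helper: scatter each sample's points into a grid x grid map.

    Assumes x has shape (batch, n_pts, 3), channels 0:2 are coordinates in
    [-1, 1], and channel 2 is the value to store (an assumption, not from the post).
    """
    batch = tf.shape(x)[0]
    n_pts = tf.shape(x)[1]
    # Same mapping as the original code: (batch, n_pts, 2) integer cells.
    xy = tf.cast(x[:, :, 0:2] * 15.5 + 15.5, tf.int64)
    # Batch index for every point, broadcast to (batch, n_pts, 1).
    b = tf.range(batch, dtype=tf.int64)[:, None, None]
    b = tf.broadcast_to(b, tf.stack([batch, n_pts, 1]))
    # (batch, n_pts, 3) -> (batch * n_pts, 3): rank 2, as SparseTensor requires.
    idx = tf.reshape(tf.concat([b, xy], axis=-1), [-1, 3])
    # One value per index row: shape (batch * n_pts,).
    vals = tf.reshape(x[:, :, 2], [-1])
    dense_shape = tf.cast(tf.stack([batch, grid, grid]), tf.int64)
    # reorder sorts the indices into the canonical row-major order.
    return tf.sparse.reorder(tf.SparseTensor(idx, vals, dense_shape))

# Usage: one sample with two points at opposite corners of the grid.
x = tf.constant([[[-1.0, -1.0, 5.0], [1.0, 1.0, 7.0]]])
st = to_sparse_grid(x)
dense = tf.sparse.to_dense(st)
print(dense.shape)  # (1, 32, 32)
```

If two points of the same sample can fall into the same cell, `tf.sparse.to_dense` will reject the duplicate indices. In that case `tf.scatter_nd(idx, vals, dense_shape)`, which sums duplicates and returns a dense tensor directly, may be a simpler route.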