Dimension error with tf.scatter_nd in TensorFlow / Keras

Time: 2018-12-10 05:12:19

Tags: python tensorflow keras tensor keras-2

My code:

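from keras import backend as K
from keras.layers import Reshape

# output3d and make_kernel are defined elsewhere in the asker's model
reshape_out = Reshape((21, 3), input_shape=(21*3,), name='reshape_to_21_3')(output3d)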

def proj_output_shape(shp):
    return (None, 32, 32, 1)

def f(x):
    import tensorflow as tf
    batch_size = K.shape(x)[0]
    print('x.shape={0}'.format(x.shape))

    idx = K.cast(x[:, :, 0:2]*15.5+15.5, "int32")
    print('idx.shape={0}'.format(idx.shape))

    # z = mysparse_to_dense(idx, (K.shape(x)[0], 32, 32), 1.0, 0.0, name='sparse_tensor')
    updates = tf.ones([batch_size, 21])
    print('updates.shape={0}'.format(updates.shape))

    #shape = tf.Variable(np.array([batch_size, 32, 32]))
    #print('shape.shape={0}'.format(shape))

    z = tf.scatter_nd(indices=idx,
                      updates=updates,
                      shape=(batch_size, 32, 32),
                      name='cool')

    print('z={0}'.format(z))
    #z = tf.add(z, z)
    #z = tf.sparse_add(tf.zeros(z.dense_shape), z)
    z = K.reshape(z, (K.shape(x)[0], 32, 32, 1))
    print('z.shape={0}'.format(z.shape), z)

    fil = make_kernel(1.0)
    fil = K.reshape(fil, (5, 5, 1, 1))
    print('fil.shape={0}'.format(fil.shape), fil)

    r = K.conv2d(z,kernel=fil, padding='same', data_format="channels_last")
    print('r.shape={0}'.format(r.shape), r)

    return r

Output:

x.shape=(?, 21, 3)
idx.shape=(?, 21, 2)
updates.shape=(?, 21)

Error:

ValueError: The inner 1 dimensions of output.shape=[?,?,?] must match the inner 0 dimensions of updates.shape=[?,21]: Shapes must be equal rank, but are 1 and 0 for 'projection_4/cool' (op: 'ScatterNd') with input shapes: [?,21,2], [?,21], [3].
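The message follows from the shape rule of tf.scatter_nd: updates must have shape indices.shape[:-1] + shape[indices.shape[-1]:]. Here idx has a last dimension of 2, so each index addresses only the first two dimensions of the rank-3 output (batch, 32), and every update would have to be a whole row of 32 values, i.e. updates of shape [?, 21, 32] rather than [?, 21]. It also means the first component of idx is being read as a batch index. The small helper below (a hypothetical name, written for illustration only) just evaluates that rule for the shapes in the question, with a made-up batch size of 4:

```python
def expected_updates_shape(indices_shape, output_shape):
    """Shape that tf.scatter_nd requires for `updates`:
    indices.shape[:-1] + output_shape[indices.shape[-1]:]."""
    depth = indices_shape[-1]  # how many leading output dims each index addresses
    if depth > len(output_shape):
        raise ValueError("index depth exceeds output rank")
    return tuple(indices_shape[:-1]) + tuple(output_shape[depth:])


B = 4  # hypothetical batch size, for illustration only
# The question's setup: 2-component indices into a rank-3 output.
print(expected_updates_shape((B, 21, 2), (B, 32, 32)))  # (4, 21, 32)
# With a batch index prepended, each index addresses a single element.
print(expected_updates_shape((B, 21, 3), (B, 32, 32)))  # (4, 21)
```

The "inner 1 dimensions" versus "inner 0 dimensions" in the error are exactly the trailing (32,) on the output side versus the empty trailing shape of updates=[?, 21].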

How can I fix this? Thanks.
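One possible approach, sketched here under assumptions rather than as a confirmed answer, is to prepend the batch position to every (row, col) pair so that idx becomes [batch, 21, 3]; then each index addresses a single pixel and updates of shape [batch, 21] are valid. The sketch below emulates this in NumPy (scatter_nd_np and with_batch_index are hypothetical helpers, and the random input stands in for the asker's model output). In TensorFlow the analogous construction would presumably use tf.range(batch_size), tf.tile and tf.concat(..., axis=-1) before calling tf.scatter_nd; that translation has not been run against the asker's model.

```python
import numpy as np


def scatter_nd_np(indices, updates, shape):
    """NumPy emulation of tf.scatter_nd for full-depth indices
    (indices.shape[-1] == len(shape)). Duplicate indices are summed,
    as tf.scatter_nd does."""
    out = np.zeros(shape, dtype=updates.dtype)
    flat_idx = indices.reshape(-1, indices.shape[-1])
    np.add.at(out, tuple(flat_idx.T), updates.reshape(-1))
    return out


def with_batch_index(idx):
    """Prepend the batch position to each (row, col) pair: [B, N, 2] -> [B, N, 3]."""
    batch, n = idx.shape[0], idx.shape[1]
    b = np.broadcast_to(np.arange(batch).reshape(-1, 1, 1), (batch, n, 1))
    return np.concatenate([b, idx], axis=-1)


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(2, 21, 3))  # hypothetical batch of 2 samples
idx = (x[:, :, 0:2] * 15.5 + 15.5).astype(np.int32)  # same mapping as in f(x)
full_idx = with_batch_index(idx)
updates = np.ones((2, 21), dtype=np.float32)
z = scatter_nd_np(full_idx, updates, (2, 32, 32))
# Each sample receives 21 unit updates, so its pixels sum to 21.
print(full_idx.shape, z.shape, z.sum(axis=(1, 2)))
```

Because duplicates are summed, two keypoints that land on the same pixel produce a value of 2.0 there rather than 1.0; clipping z afterwards is one option if a binary map is wanted.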

0 answers