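I am using Tensorflow.js to run predictions with a model that I trained in Keras. However, when I feed it a 4-dimensional tensor, I get the following error: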
UnhandledPromiseRejectionWarning: Unhandled promise rejection (rejection id: 1): Error: dot support for x of rank 4 is not yet implemented: x shape = 32,1,1,100
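I can't find anything about this error online. I suspect it is because Tensorflow.js does not support this yet, but I'm not sure. Any idea where I can find more information?
Here is my code; the line that throws the error is model.predict(noise_tensor). Most of the unrelated code is left out: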
noise_tensor.print(true)
generated_images = model.predict(noise_tensor) // error occurs here
Here is the printout of my 4D tensor:
Tensor
dtype: float32
rank: 4
shape: [64,1,1,100]
values:
[ [ [[0.3799773 , -0.0252707, 0.0118336 , ..., 0.1703698 , -0.0649208, 0.2152225 ],]],
[ [[0.219656 , 0.2850143 , -0.1078744, ..., 0.1627689 , -0.0838831, -0.1112608],]],
[ [[-0.1295149, -0.08308 , 0.1872116 , ..., -0.2033772, -0.4184959, -0.3357461],]],
...
[ [[0.0029674 , 0.0422036 , 0.067896 , ..., 0.1368463 , 0.1122015 , -0.0395375],]],
[ [[0.043546 , -0.0281712, 0.0898769 , ..., 0.205565 , 0.1444133 , 0.0067788 ],]],
[ [[-0.1089588, -0.0161969, -0.0724337, ..., 0.1427118 , -0.2577117, 0.0013836 ],]]]
Here is the summary of the Keras model:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_1 (Dense)              (None, 1, 1, 32768)       3309568
_________________________________________________________________
reshape_1 (Reshape)          (None, 8, 8, 512)         0
_________________________________________________________________
batch_normalization_1 (Batch (None, 8, 8, 512)         2048
_________________________________________________________________
activation_1 (Activation)    (None, 8, 8, 512)         0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 16, 16, 256)       3277056
_________________________________________________________________
batch_normalization_2 (Batch (None, 16, 16, 256)       1024
_________________________________________________________________
activation_2 (Activation)    (None, 16, 16, 256)       0
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 32, 32, 128)       819328
_________________________________________________________________
batch_normalization_3 (Batch (None, 32, 32, 128)       512
_________________________________________________________________
activation_3 (Activation)    (None, 32, 32, 128)       0
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 64, 64, 64)        204864
_________________________________________________________________
batch_normalization_4 (Batch (None, 64, 64, 64)        256
_________________________________________________________________
activation_4 (Activation)    (None, 64, 64, 64)        0
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 128, 128, 1)       1601
_________________________________________________________________
activation_5 (Activation)    (None, 128, 128, 1)       0
=================================================================
Total params: 7,616,257
Trainable params: 7,614,337
Non-trainable params: 1,920
_________________________________________________________________
And the corresponding code in Python:
from keras.models import Sequential
from keras.layers import (Dense, Reshape, BatchNormalization,
                          Activation, Conv2DTranspose)
from keras.optimizers import Adam


def construct_generator():
    generator = Sequential()
    # Project the noise vector to 8 * 8 * 512 features, then reshape to a feature map.
    generator.add(Dense(units=8 * 8 * 512,
                        kernel_initializer='glorot_uniform',
                        input_shape=(1, 1, 100)))
    generator.add(Reshape(target_shape=(8, 8, 512)))
    generator.add(BatchNormalization(momentum=0.5))
    generator.add(Activation('relu'))
    # Each transposed convolution doubles the spatial size: 8 -> 16 -> 32 -> 64 -> 128.
    generator.add(Conv2DTranspose(filters=256, kernel_size=(5, 5),
                                  strides=(2, 2), padding='same',
                                  data_format='channels_last',
                                  kernel_initializer='glorot_uniform'))
    generator.add(BatchNormalization(momentum=0.5))
    generator.add(Activation('relu'))
    generator.add(Conv2DTranspose(filters=128, kernel_size=(5, 5),
                                  strides=(2, 2), padding='same',
                                  data_format='channels_last',
                                  kernel_initializer='glorot_uniform'))
    generator.add(BatchNormalization(momentum=0.5))
    generator.add(Activation('relu'))
    generator.add(Conv2DTranspose(filters=64, kernel_size=(5, 5),
                                  strides=(2, 2), padding='same',
                                  data_format='channels_last',
                                  kernel_initializer='glorot_uniform'))
    generator.add(BatchNormalization(momentum=0.5))
    generator.add(Activation('relu'))
    generator.add(Conv2DTranspose(filters=1, kernel_size=(5, 5),
                                  strides=(2, 2), padding='same',
                                  data_format='channels_last',
                                  kernel_initializer='glorot_uniform'))
    generator.add(Activation('tanh'))
    optimizer = Adam(lr=0.00015, beta_1=0.5)
    generator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=None)
    print('generator')
    generator.summary()
    return generator
This is a bug in tensorflow.js. For future visitors, see the GitHub thread here.
Answer 0 (score: 1)
Currently, the input to tf.dot should be of rank 1 or 2.
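To illustrate why the rank limit need not block this model: a Dense layer only acts on the last axis of its input, so a rank-4 input of shape (batch, 1, 1, 100) can be flattened to rank 2, multiplied, and reshaped back with the same result. The NumPy sketch below checks that equivalence. The helper name dense_via_rank2_dot and the small units value are made up for the illustration and are not part of Keras or Tensorflow.js. One possible workaround along these lines, which is an assumption rather than something confirmed in the question, is to build the generator with input_shape=(100,) so that the Dense layer receives a rank-2 tensor, and to feed noise of shape [batch, 100] on the JavaScript side.

```python
import numpy as np


def dense_via_rank2_dot(x, kernel, bias):
    """Apply a Dense-style transform to x (rank >= 2) using only a rank-2 dot.

    Hypothetical helper for illustration; not a Keras or Tensorflow.js API.
    """
    lead_shape = x.shape[:-1]                 # e.g. (64, 1, 1)
    x2d = x.reshape(-1, x.shape[-1])          # flatten to rank 2, e.g. (64, 100)
    y2d = x2d @ kernel + bias                 # rank-2 dot product
    return y2d.reshape(lead_shape + (kernel.shape[1],))


rng = np.random.default_rng(0)
noise = rng.normal(size=(64, 1, 1, 100)).astype(np.float32)
units = 16  # small stand-in for 8 * 8 * 512 to keep the example light
kernel = rng.normal(size=(100, units)).astype(np.float32)
bias = np.zeros(units, dtype=np.float32)

out = dense_via_rank2_dot(noise, kernel, bias)
# Reference: contract the last axis of the rank-4 input directly.
reference = np.tensordot(noise, kernel, axes=1) + bias

print(out.shape)
print(np.allclose(out, reference, rtol=1e-4, atol=1e-4))
```

Because the kernel shape (100, units) does not depend on the extra singleton axes, the Dense kernel of a model rebuilt with a rank-2 input would have the same shape as before; whether previously trained weights load cleanly into such a model should be verified rather than assumed.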