How do I match logits and labels in tf.keras?

Time: 2019-05-04 20:53:11

Tags: python tensorflow keras tf.keras

I have a pretrained model with the following architecture.

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 200)               0         
_________________________________________________________________
embedding (Embedding)        (None, 200, 300)          402877800 
_________________________________________________________________
spatial_dropout1d (SpatialDr (None, 200, 300)          0         
_________________________________________________________________
bidirectional (Bidirectional (None, 128)               186880    
_________________________________________________________________
dropout (Dropout)            (None, 128)               0         
_________________________________________________________________
batch_normalization_v1 (Batc (None, 128)               512       
_________________________________________________________________
dense (Dense)                (None, 6)                 774       
_________________________________________________________________
reshape (Reshape)            (None, 6)                 0         
=================================================================
Total params: 403,065,966
Trainable params: 187,910
Non-trainable params: 402,878,056
_________________________________________________________________

The Reshape layer is there to make sure the logits have the same shape as the labels.

I have one input sample, a TensorFlow `Tensor` of shape (200,), and its label, also a TensorFlow `Tensor`, of shape (6,).

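My goal is to evaluate a single sample with the evaluate method that tf.keras provides. To keep things simple, I convert the Tensor objects to NumPy arrays. Because the model accepts inputs of shape [None, 200], I have to reshape the input before feeding it to the model. The model then produces logits of shape (1, 6), which the Reshape layer is supposed to reshape to (6,).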
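The shapes involved can be checked with NumPy alone. This is a minimal sketch in which zero-filled placeholder arrays (an assumption, not real data) stand in for sample_data.x.numpy() and sample_data.y.numpy(); it only shows how the batch axis is added and does not call the model.

```python
import numpy as np

# Placeholders standing in for sample_data.x.numpy() and sample_data.y.numpy()
# (hypothetical values; only the shapes matter here).
x = np.zeros(200, dtype=np.int32)    # one input sample, shape (200,)
y = np.zeros(6, dtype=np.float32)    # its label, shape (6,)

# Add a leading batch axis so the input matches the model's [None, 200].
x_batch = x.reshape(1, 200)

# The same operation applied to the label gives shape (1, 6).
y_batch = y.reshape(1, 6)

print(x.shape, x_batch.shape)
print(y.shape, y_batch.shape)
```

Note that the model summary lists the Reshape layer's output as (None, 6), so with a batch of one the logits would keep the shape (1, 6) rather than becoming (6,).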

# clone_model is the Keras model
# sample_data.x is the input
# sample_data.y is the label

clone_model.evaluate([sample_data.x.numpy().reshape(1,200)], [sample_data.y.numpy()])

But in the end I get the following error message:

InvalidArgumentError: logits and labels must have the same first dimension, got logits shape [1,6] and labels shape [6]
     [[{{node loss_4/dense_loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}}]] [Op:StatefulPartitionedCall]
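The op named in the traceback, SparseSoftmaxCrossEntropyWithLogits, takes logits of shape [batch, num_classes] and integer class labels of shape [batch]. The following is a minimal pure-Python sketch of that shape contract, not TensorFlow's implementation; the function name sparse_softmax_xent and the sample values are made up for illustration.

```python
import math


def sparse_softmax_xent(logits, labels):
    """Per-sample cross-entropy for integer class labels.

    logits: list of rows, shape [batch, num_classes]
    labels: list of ints, shape [batch], one class index per sample
    """
    if len(logits) != len(labels):
        raise ValueError(
            "logits and labels must have the same first dimension, "
            f"got {len(logits)} and {len(labels)}"
        )
    losses = []
    for row, label in zip(logits, labels):
        # Numerically stable log-sum-exp over the class axis.
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_z - row[label])
    return losses


logits = [[0.0] * 6]  # shape [1, 6], as in the error message

# One integer class index per sample: first dimensions match (1 == 1).
ok = sparse_softmax_xent(logits, [2])

# A length-6 label vector has first dimension 6, which reproduces the mismatch.
try:
    sparse_softmax_xent(logits, [0, 0, 1, 0, 0, 0])
    mismatch = None
except ValueError as err:
    mismatch = str(err)

print(ok, mismatch)
```

If the model was compiled with a sparse categorical cross-entropy loss (an assumption based only on the op name in the traceback), each sample's label would have to be a single class index rather than a length-6 vector. It is also possible that Keras flattens the labels before calling this op, which would make labels of shape (1, 6) arrive as [6] as well; this has not been verified against the TensorFlow version in use here.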

To work around this, I removed the Reshape layer and reshaped the label to (1, 6) (sample_data.y.numpy().reshape(1,6)). That did not help; I ended up with the same error.

Could anyone point out what I am missing here? Thanks in advance.

0 answers:

No answers yet.