TensorFlow不在所有输入上进行训练

时间:2020-03-11 05:15:30

标签: tensorflow

我正在尝试在5774上训练TF模型。但是它卡在了96个示例中,justs跳到了下一个时代,忽略了大多数示例。 TF为什么显示这种行为以及如何解决?

users/bookings/4

输出:

model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['acc']
    )

callback = tf.keras.callbacks.EarlyStopping(monitor='acc', patience=50)
history = model.fit(
    x=[train_ids, train_masks, train_segments],
    y=train_y,
    batch_size=32,
    epochs=10000,
    verbose=1,
    callbacks=[callback]
    )

1 个答案:

答案 0 :(得分:1)

在我的情况下,train_ids,train_masks和train_segments是n个np.array的列表,其形状为(96,)。在用steps_per_epoch = 5774 // 32强制拟合之后,它显示了正确的消息错误:尽管日志中显示的是5774,但输入仅具有96个样本。

将列表投射到np.array可以解决问题,尽管我认为tensorflow日志中仍然存在错误。