使用 GridSearchCV 训练 ANN 时发出警告

时间:2021-07-31 11:30:03

标签: python tensorflow artificial-intelligence

我正在学习构建 ANN,但是当我尝试使用 GridSearchCV 调整参数时出现一些错误:

 def build_classifier(optimizer):
            classifier = Sequential()
            # First layer
            classifier.add(Dense(units = 6, activation = "relu", 
                                 kernel_initializer = "uniform", input_dim = 13))
            #classifier.add(Dropout(rate = 0.1)) # randomly dropout some neurons
            # Second layer
            classifier.add(Dense(units = 6, activation = "relu", 
                                 kernel_initializer = "uniform"))
            #classifier.add(Dropout(rate = 0.1)) # randomly dropout some neurons
            # Last layer
            classifier.add(Dense(units = 1, activation = "sigmoid", 
                                 kernel_initializer = "uniform"))
            classifier.compile(optimizer = "adam", loss = "BinaryCrossentropy", 
                              metrics = "accuracy")
            return classifier
            classifier = KerasClassifier(build_fn = build_classifier)
        parameters = {"batch_size": [25, 32], "epochs": [100, 500], 
                      "optimizer": ["adam", "rmsprop"]}
        grid_search = GridSearchCV(estimator = classifier, param_grid = parameters,
                                  scoring = "accuracy", cv = 10)
        grid_search = grid_search.fit(x_train, y_train)

所以当我运行这个时:

parameters = {"batch_size": [25, 32], "epochs": [100, 500], 
              "optimizer": ["adam", "rmsprop"]}
grid_search = GridSearchCV(estimator = classifier, param_grid = parameters,
                          scoring = "accuracy", cv = 10)
grid_search = grid_search.fit(x_train, y_train)

我遇到了这个警告:

D:\Anaconda\lib\site-packages\tensorflow\python\keras\engine\sequential.py:455: UserWarning: `model.predict_classes()` is deprecated and will be removed after 2021-01-01. Please use instead:* `np.argmax(model.predict(x), axis=-1)`,   if your model does multi-class classification   (e.g. if it uses a `softmax` last-layer activation).* `(model.predict(x) > 0.5).astype("int32")`,   if your model does binary classification   (e.g. if it uses a `sigmoid` last-layer activation).
  warnings.warn('`model.predict_classes()` is deprecated and '

我尝试了很多方法来解决这个问题,但没有奏效。如果有人能对此提供任何解决方案,我将不胜感激!

1 个答案:

答案 0 :(得分:0)

如警告中所述,model.predict_classes() 已弃用,它将在新版本中删除。建议改用 model.predict(x)。详细看一下模型预测 API here

此警告消息不会停止执行,您仍然可以训练模型而不会出现任何错误。您可以通过

抑制警告
import logging, os
logging.disable(logging.WARNING)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf