StackEnsemble in deep learning

Time: 2020-06-03 09:55:10

Tags: python deep-learning stack ensemble-learning

I have been trying to use StackEnsemble from the DeepStack library. The first model is VGG16, the second is VGG19, and the last is a CNN with 6 fully connected layers.
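
In a stack ensemble like this, each member (here VGG16, VGG19 and the CNN) predicts class probabilities, and a second-level model is trained on those predictions rather than on the images themselves. Below is a minimal pure-Python sketch of that feature-building step. It assumes each member returns one probability vector per sample; the function name `build_meta_features` and the probability values are hypothetical and made up for illustration, and this is not DeepStack's actual implementation.

```python
from typing import List, Sequence


def build_meta_features(
    member_probs: Sequence[Sequence[Sequence[float]]],
) -> List[List[float]]:
    """Concatenate each member's class probabilities sample by sample.

    member_probs[m][i] is the probability vector member m predicted for
    sample i (hypothetical layout, chosen for illustration).
    """
    if not member_probs:
        raise ValueError("need at least one member")
    n_samples = len(member_probs[0])
    if any(len(probs) != n_samples for probs in member_probs):
        raise ValueError("all members must predict the same number of samples")

    meta: List[List[float]] = []
    for i in range(n_samples):
        row: List[float] = []
        for probs in member_probs:
            row.extend(probs[i])  # append this member's vector for sample i
        meta.append(row)
    return meta


# Made-up outputs for 2 samples and 2 classes from three hypothetical members.
vgg16_probs = [[0.9, 0.1], [0.2, 0.8]]
vgg19_probs = [[0.7, 0.3], [0.4, 0.6]]
cnn_probs = [[0.6, 0.4], [0.1, 0.9]]

meta = build_meta_features([vgg16_probs, vgg19_probs, cnn_probs])
print(meta)
# [[0.9, 0.1, 0.7, 0.3, 0.6, 0.4], [0.2, 0.8, 0.4, 0.6, 0.1, 0.9]]
```

Each meta-feature row has n_members * n_classes columns, so whatever the meta-learner receives at prediction time has to be built the same way from the members' outputs on the test data.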

# Load Keras models
import tensorflow as tf

model1 = tf.keras.models.load_model('VGG16.h5')
model2 = tf.keras.models.load_model('vgg19.h5')
model3 = tf.keras.models.load_model('basic_cnn.h5')

from deepstack.base import KerasMember

member1 = KerasMember(name="model1", keras_model=model1, train_batches=train_generator, val_batches=validation_generator)
member2 = KerasMember(name="model2", keras_model=model2, train_batches=train_generator, val_batches=validation_generator)
member3 = KerasMember(name="model3", keras_model=model3, train_batches=train_generator, val_batches=validation_generator)


from deepstack.ensemble import StackEnsemble
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier

# Ensure you have scikit-learn version >= 0.22 installed
print("sklearn version must be >= 0.22. You have:", sklearn.__version__)

stack = StackEnsemble()

# 2nd Level Meta-Learner
estimators = [
    ('rf', RandomForestClassifier(verbose=0, n_estimators=200, max_depth=15, n_jobs=20, min_samples_split=30)),
    ('etr', ExtraTreesClassifier(verbose=0, n_estimators=200, max_depth=10, n_jobs=20, min_samples_split=20))
]
# 3rd Level Meta-Learner
clf = StackingClassifier(
    estimators=estimators, final_estimator=LogisticRegression()
)

stack.model = clf
stack.add_members([member1, member2, member3])
stack.fit()
stack.describe(metric=sklearn.metrics.accuracy_score)
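
The version line above only prints the installed scikit-learn version; it does not stop the script if the requirement is not met. A minimal sketch of an actual check, assuming only the major and minor components matter (the helper names `major_minor` and `meets_minimum` are hypothetical). It compares numbers rather than strings, because a plain string comparison would rank "0.9" above "0.22".

```python
import re
from typing import Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like '0.22.2.post1' or '1.3.0'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def meets_minimum(version: str, minimum: Tuple[int, int] = (0, 22)) -> bool:
    """Compare (major, minor) as integer tuples, not as strings."""
    return major_minor(version) >= minimum


# String comparison is misleading: '9' > '2', so this prints True.
print("0.9" >= "0.22")                # True
print(meets_minimum("0.9"))           # False
print(meets_minimum("0.22.2.post1"))  # True

# In the script above this could replace the print():
# import sklearn
# if not meets_minimum(sklearn.__version__):
#     raise RuntimeError(f"scikit-learn >= 0.22 required, found {sklearn.__version__}")
```

If the `packaging` library is installed, `packaging.version.Version` offers full PEP 440 comparisons instead of this two-component shortcut.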

The code runs fine up to this point, but the following line raises an error:

predictions = stack.predict_generator(test_generator)


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-37-5442787a7d9e> in <module>
      1 #Making prediction on test images using predict_trainer
      2 
----> 3 predictions = stack.predict_generator(test_generator)

AttributeError: 'StackEnsemble' object has no attribute 'predict_generator'

1 answer:

Answer 0 (score: 0)

Take a look at the source code here -

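If I understand correctly, the error simply means that the stack object you defined has no callable method named "predict_generator". Instead, use predictions = stack.predict(test_generator); it will call the prediction function associated with your CNN and return the predicted probabilities as an np.array.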
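
The AttributeError itself is ordinary Python behavior: looking up a method that an object's class does not define fails with exactly this message. The sketch below reproduces it with a hypothetical stand-in class (`FakeStackEnsemble` is not part of DeepStack) and shows a quick way to list which prediction-related methods an object actually exposes. Running the same `dir()` check on the real stack object would show what StackEnsemble offers in the installed DeepStack version.

```python
class FakeStackEnsemble:
    """Hypothetical stand-in: defines predict() but not predict_generator()."""

    def predict(self, X):
        # Toy "prediction": sum each input row.
        return [sum(row) for row in X]


def prediction_methods(obj):
    """Return the public attribute names of obj that contain 'predict'."""
    return [name for name in dir(obj) if "predict" in name and not name.startswith("_")]


fake_stack = FakeStackEnsemble()
print(prediction_methods(fake_stack))  # ['predict']

try:
    fake_stack.predict_generator([[1, 2], [3, 4]])
except AttributeError as err:
    # Same failure mode as in the question: the method is simply not defined.
    print(err)  # 'FakeStackEnsemble' object has no attribute 'predict_generator'

# Calling a method the object does define works.
print(fake_stack.predict([[1, 2], [3, 4]]))  # [3, 7]
```

For context, tf.keras deprecated `Model.predict_generator` in TensorFlow 2.1 in favor of `Model.predict`, which accepts generators directly, so `predict` is the method to reach for on the Keras members as well.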