Keras,TensorFlow和scikit可能的兼容性问题(tf.global_variables())

时间:2017-03-27 19:23:58

标签: tensorflow scikit-learn regression keras cross-validation

我尝试使用Keras Regressor上的数据集进行小型测试(使用TensorFlow),但我遇到了一个小问题。该错误似乎是来自scikit的函数cross_val_score。它从它开始,最后一条错误消息是:

File "/usr/local/lib/python2.7/dist-packages/Keras-2.0.2-py2.7.egg/keras/backend/tensorflow_backend.py", line 298, in _initialize_variables
variables = tf.global_variables()
AttributeError: 'module' object has no attribute 'global_variables'

我的完整代码基本上是在http://machinelearningmastery.com/regression-tutorial-keras-deep-learning-library-python/中找到的带有小变化的示例。 我已经看过" '模块'对象没有属性" global_variables' "错误,它似乎是关于Tensorflow版本,但我使用最新的版本(1.0),并且代码中没有直接使用tf的函数,我可以更改。下面是我的完整代码,无论如何我可以改变它,所以它的工作原理?谢谢你的帮助

import numpy
import pandas
import sys

from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasRegressor
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.datasets import load_svmlight_file


# define base mode
def baseline_model():
        # create model
        model = Sequential()
        model.add(Dense(68, activation="relu", kernel_initializer="normal", input_dim=68))
        model.add(Dense(1, kernel_initializer="normal"))
        # Compile model
        model.compile(loss='mean_squared_error', optimizer='adam')
        return model

X, y, query_id = load_svmlight_file(str(sys.argv[1]), query_id=True)
scaler = StandardScaler()
X = scaler.fit_transform(X.toarray())

# fix random seed for reproducibility
seed = 1
numpy.random.seed(seed)
# evaluate model with standardized dataset
estimator = KerasRegressor(build_fn=baseline_model, nb_epoch=100, batch_size=5, verbose=0)

kfold = KFold(n_splits=5, random_state=seed)
results = cross_val_score(estimator, X, y, cv=kfold)
print("Results: %.2f (%.2f) MSE" % (results.mean(), results.std()))

1 个答案:

答案 0 :(得分:1)

您可能正在使用旧的Tensorflow版本安装tensorflow 1.2.0rc2,您应该没问题。