I am trying to load a Keras model for prediction only (i.e., according to Pepslee's post here, I shouldn't have to compile the model).
When I try to use `model.predict_generator()`, I get:
```
Using TensorFlow backend.
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/user/pkgs/anaconda2/lib/python2.7/threading.py", line 801, in __bootstrap_inner
    self.run()
  File "/user/pkgs/anaconda2/lib/python2.7/threading.py", line 754, in run
    self.__target(*self.__args, **self.__kwargs)
  File "/user/pkgs/anaconda2/lib/python2.7/site-packages/keras/utils/data_utils.py", line 559, in _run
    sequence = list(range(len(self.sequence)))
ValueError: __len__() should return >= 0
```
I am using TensorFlow 1.12.0 and Keras 2.2.4. I need these versions to stay compatible with a cuDNN version that I have no control over.
How can I fix this error?
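The message itself comes from Python, not from Keras: the built-in `len()` raises `ValueError: __len__() should return >= 0` whenever an object's `__len__` returns a negative integer. The traceback shows Keras hitting this in `list(range(len(self.sequence)))`. A minimal, Keras-free sketch that triggers the same error (the class and function names here are hypothetical stand-ins, not Keras APIs):

```python
class NegativeLengthSequence(object):
    """Hypothetical stand-in for a keras.utils.Sequence whose __len__ is negative."""

    def __len__(self):
        return -1  # any negative value triggers the error


def enqueue_indices(sequence):
    # Mirrors the data_utils.py line shown in the traceback.
    return list(range(len(sequence)))


try:
    enqueue_indices(NegativeLengthSequence())
except ValueError as err:
    print(err)  # __len__() should return >= 0
```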
EDIT
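I was asked for an example. Unfortunately, there is too much proprietary information here for me to give more detail, but here are the bare bones (note that the model is not actually an LSTM):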
```python
import glob

import keras
import numpy as np
from keras.models import load_model


class LSTMmodel():
    def __init__(self, hid1=10, batch_size=32, mode='test'):
        self.hid_dim_1 = hid1
        self.batch_size = batch_size  # read_data() below relies on this attribute
        self.t_per_e, self.test_generator = self.read_data()
        # Load the entire fitted model
        model_name = ''.join(glob.glob('*model.h5'))
        self.__model = load_model(model_name, compile=False)

    def read_data(self):
        num_test_minibatches = 10
        test_IDs = range(111, 111 + 10)
        params = {'list_IDs': test_IDs, 'batch_size': self.batch_size, 'n_vars': 354}
        test_generator = DataGenerator(test_IDs, **params)
        t_per_e = int(len(test_IDs) - self.batch_size + 1)
        return t_per_e, test_generator

    def lstm_model(self):
        # Model building here. Not needed, since not compiling the model
        return 0

    def lstm_predict(self):
        pred = self.__model.predict_generator(self.test_generator, self.t_per_e)
        return pred


class DataGenerator(keras.utils.Sequence):
    # Other methods in here as necessary

    def __len__(self):
        'Denotes the number of batches per epoch'
        batches_per_epoch = int(np.floor(len(self.list_IDs) - self.batch_size + 1))
        return batches_per_epoch

    def __data_generation(self, other_params_here):
        'Generates data containing batch_size samples'
        return preprocessed_data  # placeholder, details omitted


def test_lstm():
    test_inst = LSTMmodel(hid1=10)  # hid1 is a hyperparameter
    test_prediction = test_inst.lstm_predict()
    return test_prediction


if __name__ == '__main__':
    testvals = test_lstm()
```
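Plugging the numbers from this skeleton into the two length formulas shows what they evaluate to. This assumes the real run uses the default `batch_size=32` and the 10 test IDs shown; `math.floor` stands in for `np.floor` only to avoid the NumPy dependency (both leave an integer value unchanged):

```python
import math

# Values copied from the skeleton: 10 test IDs and the default batch_size=32
# (assumption: the real run uses the same values).
test_IDs = list(range(111, 111 + 10))
batch_size = 32

# Same arithmetic as DataGenerator.__len__.
batches_per_epoch = int(math.floor(len(test_IDs) - batch_size + 1))

# Same arithmetic as t_per_e in LSTMmodel.read_data().
t_per_e = int(len(test_IDs) - batch_size + 1)

print(len(test_IDs), batches_per_epoch, t_per_e)  # 10 -21 -21
```

Both values are negative whenever there are fewer test IDs than `batch_size - 1`, and a negative `__len__` is exactly what `len()` rejects.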
Basically, the workflow is:
1) `test_lstm()` creates an instance of the `LSTMmodel` class and then calls `lstm_predict`.
2) `lstm_predict` calls `predict_generator`, passing it the generator for the test set and the number of batches to draw from it (`steps` here).
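3) The test-set generator is created as an instance of the `DataGenerator()` class in the `read_data()` method of the `LSTMmodel()` class. Importantly, the test data generator is created in the same way as the training and validation data generators.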
4) `self.__model` is created by loading the fully trained model in the `__init__()` method of the `LSTMmodel()` class.
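Because step 3 builds the test generator the same way as the training and validation generators, the same `batch_size` is presumably applied to a much smaller list of test IDs. A hedged, Keras-free sketch of a length check that fails early with the actual numbers instead of the generic `len()` error. `SlidingWindowIDs` is a hypothetical name, and the sketch assumes the `len(list_IDs) - batch_size + 1` sliding-window count is the intended batch count:

```python
class SlidingWindowIDs(object):
    """Hypothetical, Keras-free stand-in for DataGenerator's length logic.

    Assumes one batch per sliding-window start, matching the
    len(list_IDs) - batch_size + 1 formula in the question.
    """

    def __init__(self, list_IDs, batch_size):
        self.list_IDs = list(list_IDs)
        self.batch_size = batch_size

    def __len__(self):
        n_ids = len(self.list_IDs)
        if n_ids < self.batch_size:
            # Report the actual numbers instead of letting len() raise
            # the generic "__len__() should return >= 0".
            raise ValueError(
                "batch_size=%d is larger than the %d available IDs"
                % (self.batch_size, n_ids)
            )
        return n_ids - self.batch_size + 1


# Usage with the skeleton's 10 test IDs and a smaller batch size:
print(len(SlidingWindowIDs(range(111, 121), batch_size=5)))  # 6
```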
How do I get rid of the error?