无法在加载的模型上使用predict_generator

时间:2019-04-21 17:38:06

标签: keras

我试图加载Keras模型仅用于预测(即,根据Pepslee's post here,我不必编译模型)。

当我尝试使用model.predict_generator()时,我得到:

Using TensorFlow backend.
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/user/pkgs/anaconda2/lib/python2.7/threading.py", line 801, in __bootstrap_inner
    self.run()
  File "/user/pkgs/anaconda2/lib/python2.7/threading.py", line 754, in run
    self.__target(*self.__args, **self.__kwargs)
  File "/user/pkgs/anaconda2/lib/python2.7/site-packages/keras/utils/data_utils.py", line 559, in _run
    sequence = list(range(len(self.sequence)))
ValueError: __len__() should return >= 0

我正在使用Tensorflow版本1.12.0,Keras版本2.2.4。我需要使用这些版本来确保与我无法控制的cuDNN版本兼容。

如何解决此错误?

编辑

有人问我一个例子。不幸的是,这里有太多专有信息让我无法提供更多细节,但这是裸露的骨骼(请注意,该模型实际上不是LSTM):

class LSTMmodel():

    def __init__(self, hid1 = 10, batch_size=32, mode='test'):
        self.hid_dim_1 = hid1 

        self.t_per_e, self.test_generator = self.read_data()

        #Load the entire fitted model
        model_name = ''.join(glob.glob('*model.h5'))
        self.__model = load_model(model_name, compile=False)

    def read_data(self):

        num_test_minibatches = 10
        test_IDs = range(111, 111+10)
        params = {'list_IDs': test_IDs, 'batch_size': self.batch_size, 'n_vars': 354}

        test_generator = DataGenerator(test_IDs, **params)
        t_per_e = int(len(test_IDs) - self.batch_size + 1)

        return t_per_e, test_generator

    def lstm_model():
        #Model building here. Not needed, since not compiling the model
        return 0

    def lstm_predict(self):
        pred = self.__model.predict_generator(self.test_generator, self.t_per_e)
        return pred

class DataGenerator(keras.utils.Sequence):

    #Other methods in here as necessary

    def __len__(self):
        'Denotes the number of batches per epoch'
        batches_per_epoch = int(np.floor(len(self.list_IDs) - self.batch_size + 1))
        return batches_per_epoch

    def __data_generation(self, other_params_here):
        'Generates data containing batch_size samples'
        return preprocessed_data

def test_lstm():
    test_inst = LSTMmodel(hid1=10) #hid1 is a hyperparameter
    test_prediction = test_inst.lstm_predict()
    return test_prediction


if __name__ == '__main__':
    testvals = test_lstm()

基本上,工作流程是:

1)test_lstm()创建LSTMmodel类的实例,然后调用lstm_predict

2)lstm_predict使用predict_generator来获取测试集和要生成的示例数的生成器(here中的steps)。

3)测试集的生成器被创建为类DataGenerator()的{​​{1}}方法中类read_data()的实例。重要的是,测试数据生成器的创建方式与培训数据生成器和验证数据生成器的创建方式相同。

4)LSTMmodel()是通过在LSTMmodel()类的self.__model方法中加载经过全面训练的模型而创建的。

如何摆脱错误?

0 个答案:

没有答案