Keras fit_generator: reading chunks from an HDFStore

Time: 2018-04-26 22:23:22

Tags: keras hdfstore

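I am trying to build a generator for a Keras model that will be trained on a large HDF store. To speed up training, I precomputed all features, including the one-hot encoding, and stored them in the HDFStore. So reading from it should be straightforward.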

To feed my data to the network in chunks, I am trying to use fit_generator, but I am having trouble getting it to run.

The generator:

def myGenerator(myStore, generateFrom,generateTo):
    # Create empty arrays to contain batch of features and labels
    while True:
        X = pd.read_hdf(myStore,'X',start=generateFrom,stop=generateTo)
        y = pd.read_hdf(myStore,'y',start=generateFrom,stop=generateTo)
        yield X,y

The network and fitting:

def get_model(shape):
    '''Create a keras model.'''
    inputlayer = Input(shape=shape)

    model = BatchNormalization()(inputlayer)
    model = Dense(1024, activation='relu')(model)
    model = Dropout(0.25)(model)
    model = BatchNormalization()(inputlayer)
    model = Dense(512, activation='relu')(model)
    model = Dropout(0.25)(model)
    model = BatchNormalization()(inputlayer)
    model = Dense(256, activation='relu')(model)
    model = Dropout(0.25)(model)
    model = BatchNormalization()(inputlayer)
    model = Dense(128, activation='relu')(model)
    model = Dropout(0.25)(model)

    # 11 because background noise has been taken out
    model = Dense(2, activation='tanh')(model)

    model = Model(inputs=inputlayer, outputs=model)

    return model
shape = (6603,10000)
model = get_model(shape)
model.compile(loss='mean_squared_error', optimizer=Adam(), metrics=['accuracy'])
#X = generator(myStore)
#Xt = generator(myStore)
labelbinarizer = LabelBinarizer()
y = labelbinarizer.fit_transform(y)
#yt = labelbinarizer.fit_transform(yt)

generateFrom = 0
for i in range(10):
    generateTo=generateFrom+10000
    model.fit_generator(
        generator=myGenerator(myStore,generateFrom,generateTo),
        epochs=1,
        steps_per_epoch=X[0].shape[0] // 1000)
    generateFrom=generateTo

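I have tried two approaches: putting fit_generator in a loop and passing in the range (as shown above), and handling the range inside the generator. Neither works. Currently I am getting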

TypeError: 'generator' object is not subscriptable

I probably misunderstand how fit_generator() should be used in this case. Most examples generate tensors from images.

Any hints are appreciated. Thanks

1 answer:

Answer 0 (score: 0)

The function read_hdf returns a pandas object, which needs to be converted to a NumPy array before it is passed to Keras.
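A minimal sketch of that conversion inside a chunked generator is below. It is one possible layout, not the asker's exact code. The names hdf_batch_generator, steps_per_epoch, n_rows, batch_size and the reader parameter are hypothetical, as is the file name in the usage comment. It assumes the store was written with format='table', since start/stop slicing in read_hdf only works on table-format stores. It also assumes the total row count is known up front, so steps_per_epoch can be computed from a number. The TypeError in the question most likely comes from indexing a generator object in X[0].

```python
import numpy as np
import pandas as pd


def steps_per_epoch(n_rows, batch_size):
    """Number of batches needed to cover n_rows once (ceiling division)."""
    return -(-n_rows // batch_size)


def hdf_batch_generator(store_path, n_rows, batch_size, reader=pd.read_hdf):
    """Yield (X, y) NumPy batches from an HDF store, looping forever.

    `reader` is a hypothetical hook that defaults to pd.read_hdf; it only
    exists so the generator can be exercised without a real HDF file.
    """
    while True:  # Keras expects the generator to loop indefinitely
        for start in range(0, n_rows, batch_size):
            stop = min(start + batch_size, n_rows)
            X = reader(store_path, "X", start=start, stop=stop)
            y = reader(store_path, "y", start=start, stop=stop)
            # read_hdf returns pandas objects; convert them to NumPy arrays
            yield np.asarray(X, dtype="float32"), np.asarray(y, dtype="float32")


# Usage sketch (assumes a compiled `model` and a store at "store.h5"):
# model.fit_generator(
#     generator=hdf_batch_generator("store.h5", n_rows, 1000),
#     steps_per_epoch=steps_per_epoch(n_rows, 1000),
#     epochs=10)
```

Moving the chunk offset inside the generator means a single fit_generator call covers the whole store, instead of one call per chunk in an outer loop.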