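Keras HDF5Matrix does not work with model.fit()

Date: 2017-06-05 20:53:54

Tags: python keras

I am trying to use keras.utils.io_utils.HDF5Matrix to handle a large dataset during training, but I run into a problem as soon as I try to train. Running this example: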

from keras.models import Sequential
from keras.layers import Dense
from keras.utils.io_utils import HDF5Matrix
import numpy as np

def create_dataset():
    import h5py
    X = np.random.randn(200,10).astype('float32')
    y = np.random.randint(0, 2, size=(200,1))
    f = h5py.File('test.h5', 'w')
    # Creating dataset to store features
    X_dset = f.create_dataset('my_data', (200,10), dtype='f')
    X_dset[:] = X
    # Creating dataset to store labels
    y_dset = f.create_dataset('my_labels', (200,1), dtype='i')
    y_dset[:] = y
    f.close()

create_dataset()

# Instantiating HDF5Matrix for the training set, which is a slice of the first 150 elements
X_train = HDF5Matrix('test.h5', 'my_data', start=0, end=150)
y_train = HDF5Matrix('test.h5', 'my_labels', start=0, end=150)

# Likewise for the test set
X_test = HDF5Matrix('test.h5', 'my_data', start=150, end=200)
y_test = HDF5Matrix('test.h5', 'my_labels', start=150, end=200)


model = Sequential()
model.add(Dense(64, input_shape=(10,), activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='sgd')

# Note: you have to use shuffle='batch' or False with HDF5Matrix
model.fit(X_train, y_train, batch_size=32, shuffle='batch')

model.evaluate(X_test, y_test, batch_size=32)

returns the error:

AttributeError                            Traceback (most recent call last)
<ipython-input-3-02900751f245> in <module>()
     38 
     39 # Note: you have to use shuffle='batch' or False with HDF5Matrix
---> 40 model.fit(X_train, y_train, batch_size=32, shuffle='batch')
     41 
     42 model.evaluate(X_test, y_test, batch_size=32)

...

/opt/conda/lib/python2.7/site-packages/keras/utils/io_utils.pyc in __getitem__(self, key)
     63 
     64     def __getitem__(self, key):
---> 65         start, stop = key.start, key.stop
     66         if isinstance(key, slice):
     67             if start is None:

AttributeError: 'list' object has no attribute 'start'
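
Judging only from the traceback lines shown above, `key.start` and `key.stop` appear to be read before the `isinstance(key, slice)` check, so any key that is not a slice (such as the list in this error) would fail at that line. The sketch below reproduces that ordering with hypothetical stand-in classes (`BuggyMatrix`, `CheckedMatrix`, and `try_index` are made-up names, not Keras code) and contrasts it with a version that inspects the key type first. It illustrates the suspected cause and is not a patch for the library.

```python
class BuggyMatrix(object):
    """Hypothetical stand-in mirroring the ordering seen in the traceback."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, key):
        # Attribute access happens before the type check, so any
        # non-slice key (int, list) raises AttributeError right here.
        start, stop = key.start, key.stop
        if isinstance(key, slice):
            return self.data[start:stop]
        raise IndexError("unsupported key")


class CheckedMatrix(BuggyMatrix):
    """Same stand-in, but the key type is inspected first."""

    def __getitem__(self, key):
        if isinstance(key, slice):
            return self.data[key.start:key.stop]
        if isinstance(key, int):
            return self.data[key]
        # Treat anything else as an iterable of integer indices.
        return [self.data[i] for i in key]


def try_index(matrix, key):
    """Return the indexed value, or the exception class name on failure."""
    try:
        return matrix[key]
    except AttributeError as exc:
        return type(exc).__name__
```

With a plain Python list as the backing data, the first class accepts only slices, while the second also accepts integers and lists of indices.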

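I wonder whether this is simply a bug introduced by a recent update, since this example appears to have worked for other people in the past. Even running: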

print(y_train[10])

returns a similar error:

AttributeError: 'int' object has no attribute 'start'

whereas

print(y_train[10:11])

actually works and prints:

[[1]]
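
Since slice indexing is the one access pattern that works here, one possible workaround is to walk the data in contiguous slices rather than let `fit()` index it with lists. The helper below is a minimal sketch of that idea: `iter_slice_batches` is a hypothetical name, the sample count is passed in explicitly, and NumPy arrays stand in for the HDF5Matrix objects so that the snippet runs on its own.

```python
import numpy as np


def iter_slice_batches(X, y, n_samples, batch_size):
    """Yield (X_batch, y_batch) pairs using contiguous slices only."""
    for start in range(0, n_samples, batch_size):
        stop = min(start + batch_size, n_samples)
        # Only slice keys are used, never ints or lists of indices.
        yield X[start:stop], y[start:stop]


# Stand-in data with the same shapes as the training split in the example.
X_demo = np.random.randn(150, 10).astype('float32')
y_demo = np.random.randint(0, 2, size=(150, 1))

batch_shapes = [(xb.shape, yb.shape)
                for xb, yb in iter_slice_batches(X_demo, y_demo, 150, 32)]
```

Each yielded pair could then be passed to `model.train_on_batch(X_batch, y_batch)` in a loop over epochs. Whether this sidesteps the error on the Keras version in question has not been checked here, and batches are visited in file order with no shuffling.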

0 answers
