我正在尝试使用简单的Keras顺序模型创建用于音频识别的数据集。
这是我用来创建模型的功能:
def dnn_model(input_shape, output_shape):
model = keras.Sequential()
model.add(keras.Input(input_shape))
model.add(layers.Flatten())
model.add(layers.Dense(512, activation = "relu"))
model.add(layers.Dense(output_shape, activation = "softmax"))
model.compile( optimizer='adam',
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
metrics=['acc'])
model.summary()
return model
我正在使用此Generator函数生成我的trainingsdata:
def generator(x_dirs, y_dirs, hmm, sampling_rate, parameters):
window_size_samples = tools.sec_to_samples(parameters['window_size'], sampling_rate)
window_size_samples = 2**tools.next_pow2(window_size_samples)
hop_size_samples = tools.sec_to_samples(parameters['hop_size'],sampling_rate)
for i in range(len(x_dirs)):
features = fe.compute_features_with_context(x_dirs[i],**parameters)
praat = tools.praat_file_to_target( y_dirs[i],
sampling_rate,
window_size_samples,
hop_size_samples,
hmm)
yield features,praat
变量x_dirs
和y_dirs
包含标签和音频文件的路径列表。我总共获得了8623个文件来训练我的模型。这就是我训练模型的方式:
def train_model(model, model_dir, x_dirs, y_dirs, hmm, sampling_rate, parameters, steps_per_epoch=10,epochs=10):
model.fit((generator(x_dirs, y_dirs, hmm, sampling_rate, parameters)),
epochs=epochs,
batch_size=steps_per_epoch)
return model
我现在的问题是,如果我传递所有8623文件,它将在第一个时期使用所有8623个文件来训练模型,并在第一个时期后抱怨需要steps_per_epoch * epochs
批来训练模型。
我仅对带有切片列表的8623个文件中的10个进行了测试,但随后Tensorflow抱怨需要100个批次。
那么我如何让Generator生成最有效的数据?我一直认为steps_per_epoch
仅限制每个时期接收的数据。
答案 0 :(得分:1)
fit函数将耗尽您的生成器,也就是说,一旦生成了您所有的8623批次,它将不再能够生成批次。
您要解决这样的问题:
def generator(x_dirs, y_dirs, hmm, sampling_rate, parameters, epochs=1):
for epoch in range(epochs): # or while True:
window_size_samples = tools.sec_to_samples(parameters['window_size'], sampling_rate)
window_size_samples = 2**tools.next_pow2(window_size_samples)
hop_size_samples = tools.sec_to_samples(parameters['hop_size'],sampling_rate)
for i in range(len(x_dirs)):
features = fe.compute_features_with_context(x_dirs[i],**parameters)
praat = tools.praat_file_to_target( y_dirs[i],
sampling_rate,
window_size_samples,
hop_size_samples,
hmm)
yield features,praat