使用生成器输入数据集,但得到IndexError

时间:2020-06-14 07:15:10

标签: python numpy tensorflow

model.fit(x,y, epochs=10000, batch_size=1)

以上代码可以正常工作。当我使用函数在模型中输入数据时,出现了问题。

model.fit(GData(), epochs=10000, batch_size=1)

per_sample_losses = loss_fn.call(targets [i],outs [i])
IndexError:列表索引超出范围

GData()函数如下:

def GData():
  return (x,y)

x是一个维度为(2,63,85)的numpy数组
y是一个尺寸为(2,63,41000)的numpy数组

这是整个代码:

import os
import tensorflow as tf
import numpy as np

def MSE( y_true, y_pred):
    error = tf.math.reduce_mean(tf.math.square(y_true-y_pred))
    return error

data = np.load("Data.npz")
x = data['x'] # (2,63,   85)
y = data['y'] # (2,63,41000)

frame = x.shape[1]
InSize = x.shape[2]
OutSize = y.shape[2]

def GData():
    return (x,y)


model = tf.keras.Sequential()
model.add(tf.keras.layers.GRU(1000, return_sequences=True, input_shape=(frame,InSize)))
model.add(tf.keras.layers.Dense(OutSize))

model.compile(optimizer='adam',
              loss=MSE)#'mean_squared_error')
model.fit(GData(), epochs=10000, batch_size=1)

1 个答案:

答案 0 :(得分:1)

首先,函数GData实际上不是生成器,因为它正在返回值而不是产生值。无论如何,我们应该看一下herefit()方法及其文档。 由此可见,fit()的前两个参数是x和y。进一步讲,我们看到x限于几种类型。即,生成器,numpy数组,tf.data.Datasets等。文档中需要注意的重要一点是,如果x是生成器,则它必须是A generator or keras.utils.Sequence returning (inputs, targets)。我假设这就是您想要的。在这种情况下,您将需要修改GData函数,使其实际上是一个生成器。可以这样做

batch_size = 1
EPOCHS = 10000
def GData():
    for _ in range(EPOCHS): # Iterate through epochs. Note that this can be changed to be while True so that the generator yields indefinitely. The model will stop training after the amount of epochs you specify in the fit method.
        for i in range(0, len(x), batch_size): # Iterate through batches
            yield (x[i:batch_size], y[i:batch_size]) # Yield batches for training

然后,您必须在fit()调用中指定每个时期的步数,以便模型知道何时在每个时期停止。

model.fit(GData(), epochs=EPOCHS, steps_per_epoch=x.shape[0]//batch_size)