Python Keras TensorFlow Embedding layer: indices[i,j] = k is not in [0, max_features)

Date: 2017-12-08 21:19:14

Tags: python tensorflow keras convolution multiclass-classification

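I am trying to do author identification. My train_vecs_w2v.shape = (15663, 400) and y_train.shape = (15663, 3): there are 3 labels, one-hot encoded. The problem is that I get an error in the Embedding layer: indices[0,X] = -1 is not in [0, 15663). How do I fix this? Is it my code or Keras/TensorFlow?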

from keras.models import Sequential
from keras.layers import Embedding, Conv1D, MaxPooling1D, Flatten, Dense

print('Building Model')
n = 19579
max_features = 15663
max_length = 400
EMBEDDING_DIM = 100
model7 = Sequential()

# Embedding looks up each input value as an integer row index in [0, input_dim)
model7.add(Embedding(len(train_vecs_w2v), EMBEDDING_DIM, input_length=max_length, dtype='float32', trainable=True, weights=None, embeddings_initializer='uniform', embeddings_regularizer=None, activity_regularizer=None, embeddings_constraint=None))
print(model7.output_shape)
# Keras 2 argument names: padding instead of border_mode
model7.add(Conv1D(filters=128, kernel_size=3, strides=1, activation='relu', use_bias=False, padding='same'))
print(model7.output_shape)
model7.add(MaxPooling1D(pool_size=3))
print(model7.output_shape)
model7.add(Conv1D(filters=64, kernel_size=5, strides=1, activation='relu', padding='same'))
print(model7.output_shape)
model7.add(MaxPooling1D(pool_size=5))
print(model7.output_shape)
model7.add(Flatten())  # model7.output_shape == (None, 26 * 64) == (None, 1664) for max_length = 400
print(model7.output_shape)
# Keras 2 argument names: units instead of output_dim
model7.add(Dense(units=64, activation='relu'))  # input_shape = (batch_size, input_dim)
print(model7.output_shape)
model7.add(Dense(units=32, activation='relu'))
print(model7.output_shape)
model7.add(Dense(units=3, activation='softmax'))

model7.compile(optimizer='adam',
               loss='categorical_crossentropy',
               metrics=['categorical_accuracy'])

model7.fit(train_vecs_w2v, y_train_vec, epochs=50, batch_size=32, verbose=2)
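The print(model7.output_shape) calls in the code trace how the sequence length shrinks. The arithmetic can be sketched in plain Python, assuming Keras defaults for MaxPooling1D (strides equal to pool_size, 'valid' padding) and padding='same' with strides=1 for both convolutions. The helper names are hypothetical and nothing here calls Keras.

```python
def pool_output_length(length, pool_size):
    """Output length of MaxPooling1D with default strides (= pool_size) and 'valid' padding."""
    return (length - pool_size) // pool_size + 1


def model7_shapes(max_length=400, embedding_dim=100):
    """Return (layer name, output shape) pairs for the model in the question."""
    shapes = []
    length = max_length
    shapes.append(("Embedding", (None, length, embedding_dim)))
    # Conv1D with padding='same' and strides=1 keeps the sequence length unchanged
    shapes.append(("Conv1D(128)", (None, length, 128)))
    length = pool_output_length(length, 3)
    shapes.append(("MaxPooling1D(3)", (None, length, 128)))
    shapes.append(("Conv1D(64)", (None, length, 64)))
    length = pool_output_length(length, 5)
    shapes.append(("MaxPooling1D(5)", (None, length, 64)))
    # Flatten multiplies the remaining length by the channel count
    shapes.append(("Flatten", (None, length * 64)))
    shapes.append(("Dense(64)", (None, 64)))
    shapes.append(("Dense(32)", (None, 32)))
    shapes.append(("Dense(3)", (None, 3)))
    return shapes


for name, shape in model7_shapes():
    print(name, shape)
```

With max_length = 400 the length goes 400 to 133 after the first pooling layer and 133 to 26 after the second, so Flatten outputs 26 * 64 = 1664 features.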

The error I get:

InvalidArgumentError (see above for traceback): indices[0,1] = -1 is not in [0, 15663)
     [[Node: embedding_1/Gather = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](embedding_1/embeddings/read, embedding_1/Cast)]]
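In this traceback the Gather node reads its indices from embedding_1/Cast: the Embedding input is cast to an integer type and each value is used as a row index into the embedding matrix. A minimal plain-Python sketch, assuming the float-to-integer cast truncates toward zero (as Python's int() does), shows how float word2vec features can turn into an out-of-range index such as -1. The row values and helper names are hypothetical, made up for illustration.

```python
def cast_like_embedding(row):
    """Mimic a float -> int cast that truncates toward zero (assumption, see lead-in)."""
    return [int(x) for x in row]


def invalid_indices(indices, input_dim):
    """Return (position, value) pairs outside the valid Gather range [0, input_dim)."""
    return [(j, k) for j, k in enumerate(indices) if not 0 <= k < input_dim]


# Hypothetical slice of one word2vec feature row (made-up values)
w2v_row = [0.37, -1.52, 2.9, -0.4]
indices = cast_like_embedding(w2v_row)
print(indices)                          # [0, -1, 2, 0]
print(invalid_indices(indices, 15663))  # [(1, -1)]
```

Any feature value between -2 and -1 becomes -1 after truncation, which is exactly the kind of index the Gather op rejects.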

1 answer:

Answer 0 (score: 1)

I think the problem here is the word vector count. It should be an INTEGER.
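In line with the answer: an Embedding layer expects integer token indices rather than float word2vec features, and its first argument (input_dim) is the vocabulary size (maximum index + 1), not len(train_vecs_w2v). A minimal plain-Python sketch of that preprocessing follows; the helper names and toy texts are hypothetical, and it pads at the end, whereas Keras's own pad_sequences pads at the front by default.

```python
def build_vocab(tokenized_texts):
    """Map each distinct token to an index starting at 1; 0 is reserved for padding."""
    vocab = {}
    for tokens in tokenized_texts:
        for tok in tokens:
            if tok not in vocab:
                vocab[tok] = len(vocab) + 1
    return vocab


def to_padded_indices(tokenized_texts, vocab, max_length):
    """Convert token lists to fixed-length integer sequences (truncate, then pad at the end)."""
    sequences = []
    for tokens in tokenized_texts:
        seq = [vocab[tok] for tok in tokens if tok in vocab][:max_length]
        seq += [0] * (max_length - len(seq))
        sequences.append(seq)
    return sequences


texts = [["the", "raven", "spoke"], ["the", "old", "house"]]
vocab = build_vocab(texts)
X = to_padded_indices(texts, vocab, max_length=5)
input_dim = len(vocab) + 1  # value to pass as the Embedding layer's first argument
print(X)          # [[1, 2, 3, 0, 0], [1, 4, 5, 0, 0]]
print(input_dim)  # 6
```

Every value in X then lies in [0, input_dim). One common option for reusing pretrained word2vec vectors is to pass them as the embedding matrix through the layer's weights argument, rather than feeding them as the model input.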