我创建了一个Tensorflow模型,并试图使其适合我的两个训练集。
modelA <- keras_model_sequential()
modelA %>%
layer_dense(units = 2, activation = 'relu', input_shape = c(13)) %>%
layer_dropout(rate = 0.4) %>%
layer_dense(units = 128, activation = 'relu') %>%
layer_dropout(rate = 0.3) %>%
layer_dense(units = 10, activation = 'softmax')
modelA %>% compile(
loss = 'categorical_crossentropy',
optimizer = optimizer_rmsprop(),
metrics = c('accuracy')
)
但是,我不断收到此错误,我也不知道它是从哪里来的。
fitting <- modelA %>% fit(
neg_train, pos_train,
epochs = 30, batch_size = 128,
validation_split = 0.2
)
Error in py_call_impl(callable, dots$args, dots$keywords) :
InvalidArgumentError: indices[637] = 637 is not in [0, 637) [Op:GatherV2]
637是pos_train数据帧中的行数,但是我不知道该索引命令来自何处。有谁知道这里发生了什么或我如何解决?如果有帮助的话,这里是回溯。
25.
stop(structure(list(message = "InvalidArgumentError: indices[637] = 637 is not in [0, 637) [Op:GatherV2]",
call = py_call_impl(callable, dots$args, dots$keywords),
cppstack = structure(list(file = "", line = -1L, stack = c("1 reticulate.so 0x00000001831af98e _ZN4Rcpp9exceptionC2EPKcb + 222",
"2 reticulate.so 0x00000001831b7d05 _ZN4Rcpp4stopERKNSt3__112basic_stringIcNS0_11char_traitsIcEENS0_9allocatorIcEEEE + 53", ...
24.
raise_from at <string>#3
23.
raise_from_not_ok_status at ops.py#6653
22.
gather_v2 at gen_array_ops.py#3755
21.
gather at array_ops.py#4524
20.
wrapper at dispatch.py#180
19.
gather_v2 at array_ops.py#4541
18.
wrapper at dispatch.py#180
17.
_split at data_adapter.py#1335
16.
map_structure at nest.py#617
15.
train_validation_split at data_adapter.py#1338
14.
fit at training.py#797
13.
_method_wrapper at training.py#66
12.
(structure(function (...)
{
dots <- py_resolve_dots(list(...))
result <- py_call_impl(callable, dots$args, dots$keywords) ...
11.
do.call(object$fit, args)
10.
fit.keras.engine.training.Model(., neg_train, pos_train, epochs = 30,
batch_size = 128, validation_split = 0.2)
9.
fit(., neg_train, pos_train, epochs = 30, batch_size = 128, validation_split = 0.2)
8.
function_list[[k]](value)
7.
withVisible(function_list[[k]](value))
6.
freduce(value, `_function_list`)
5.
`_fseq`(`_lhs`)
4.
eval(quote(`_fseq`(`_lhs`)), env, env)
3.
eval(quote(`_fseq`(`_lhs`)), env, env)
2.
withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))
1.
modelA %>% fit(neg_train, pos_train, epochs = 30, batch_size = 128,
validation_split = 0.2)
谢谢!