"slice index 0 of dimension 0 out of bounds" exception when using the tf.data Dataset API with Keras

Date: 2018-08-23 07:05:20

Tags: tensorflow keras

I tried to use the tf.data API to load the MNIST dataset and train a simple model, as shown below, but I got a "slice index 0 of dimension 0 out of bounds" exception. I'd like to know what my code is doing wrong.

import math
import tensorflow as tf
import numpy as np

batch_size = 32


def load_data():
    mnist = tf.keras.datasets.mnist
    (train_data, train_label), (validation_data, validation_label) = mnist.load_data()
    train_data, validation_data = train_data / 255.0, validation_data / 255.0
    train_label = train_label.astype(np.float32)
    return train_data, train_label


def build_model():
    model = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dropout(0.2),
      tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])

    model.compile(optimizer=tf.train.AdamOptimizer(),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model


train_data, train_label = load_data()
train_sample_count = len(train_data)

train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))

train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.repeat()

iter = train_dataset.make_one_shot_iterator()
train_x, train_y = iter.get_next()


model = build_model()
model.fit(
    train_dataset,
    epochs=10,
    steps_per_epoch=math.ceil(train_sample_count/batch_size)
)

The full stack trace is shown below:

  

    Epoch 1/10
    2018-08-23 16:14:49.485165: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
    2018-08-23 16:15:25.588057: W tensorflow/core/framework/op_kernel.cc:1275] OP_REQUIRES failed at strided_slice_op.cc:105 : Invalid argument: slice index 0 of dimension 0 out of bounds.
    2018-08-23 16:15:26.852912: W tensorflow/core/framework/op_kernel.cc:1275] OP_REQUIRES failed at iterator_ops.cc:910 : Cancelled:
    Traceback (most recent call last):
      File "/Users/xievi/workspace/speaker_verification/mnist_with_dataset.py", line 35, in <module>
        steps_per_epoch=math.ceil(train_sample_count/batch_size)
      File "/Users/xievi/shared_python3_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1363, in fit
        validation_steps=validation_steps)
      File "/Users/xievi/shared_python3_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py", line 205, in fit_loop
        outs = f(ins)
      File "/Users/xievi/shared_python3_venv/lib/python3.6/site-packages/tensorflow/python/keras/backend.py", line 2914, in __call__
        fetched = self._callable_fn(*array_vals)
      File "/Users/xievi/shared_python3_venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1382, in __call__
        run_metadata_ptr)
      File "/Users/xievi/shared_python3_venv/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 519, in __exit__
        c_api.TF_GetCode(self.status.status))
    tensorflow.python.framework.errors_impl.InvalidArgumentError: slice index 0 of dimension 0 out of bounds.
         [[Node: flatten/strided_slice = StridedSlice[Index=DT_INT32, T=DT_INT32, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](flatten/Shape, metrics/acc/Const, training/TFOptimizer/gradients/dense_1/Softmax_grad/Sum/reduction_indices, training/TFOptimizer/gradients/dense_1/Softmax_grad/Sum/reduction_indices)]]

1 Answer:

Answer 0: (score: 0)

Just for the record, I'll share my fix.

This is the error I got:

ValueError: slice index 0 of dimension 0 out of bounds. for '{{node strided_slice}} = StridedSlice[Index=DT_INT32, T=DT_INT32, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](Shape, strided_slice/stack, strided_slice/stack_1, strided_slice/stack_2)' with input shapes: [0], [1], [1], [1] and with computed input tensors: input[1] = <0>, input[2] = <1>, input[3] = <1>.

This error prevented my Keras model from training. I tried switching the data pipeline from a custom Keras generator to tf.data, but the error did not go away. In the end, I was able to start training by changing the loss function from:

loss = 'sparse_categorical_crossentropy'

to:

loss = 'categorical_crossentropy'

I don't know the reason, but this worked for me.
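For reference, the two losses expect different label formats: sparse_categorical_crossentropy takes integer class ids of shape (batch,), while categorical_crossentropy takes one-hot vectors of shape (batch, num_classes). Below is a minimal sketch of my own (the toy model, num_classes, and labels are assumptions for illustration, not the answerer's code) showing how to match the loss to the label format:

import numpy as np
import tensorflow as tf

num_classes = 10  # hypothetical number of classes for this sketch

# Integer class ids, shape (batch,) -> pairs with 'sparse_categorical_crossentropy'
sparse_labels = np.array([3, 1, 4], dtype=np.int64)

# One-hot vectors, shape (batch, num_classes) -> pairs with 'categorical_crossentropy'
one_hot_labels = tf.keras.utils.to_categorical(sparse_labels, num_classes=num_classes)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(num_classes, activation='softmax', input_shape=(8,)),
])

# Pick the loss that matches what the generator / tf.data pipeline actually yields:
model.compile(optimizer='adam',
              loss='categorical_crossentropy',   # labels are one-hot
              metrics=['accuracy'])
# or, with integer labels:
# model.compile(optimizer='adam',
#               loss='sparse_categorical_crossentropy',
#               metrics=['accuracy'])

If the pipeline was already yielding one-hot labels, categorical_crossentropy is the matching choice, which may be why the switch helped here; the original answer does not confirm this.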

Here is the full log of my error:

    Epoch 1/100
Traceback (most recent call last):
  File "classify_cnn_generator.py", line 175, in <module>
    history = model.fit(train_generator, 
  File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
    tmp_logs = train_function(iterator)
  File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 696, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3065, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:806 train_function  *
        return step_function(self, iterator)
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:796 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:1211 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2585 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2945 _call_for_each_replica
        return fn(*args, **kwargs)
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:789 run_step  **
        outputs = model.train_step(data)
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:748 train_step
        loss = self.compiled_loss(
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/compile_utils.py:212 __call__
        batch_dim = array_ops.shape(y_t)[0]
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py:1013 _slice_helper
        return strided_slice(
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py:1186 strided_slice
        op = gen_array_ops.strided_slice(
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/ops/gen_array_ops.py:10347 strided_slice
        _, _, _op, _outputs = _op_def_library._apply_op_helper(
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py:742 _apply_op_helper
        op = g._create_op_internal(op_type_name, inputs, dtypes=None,
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:591 _create_op_internal
        return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3477 _create_op_internal
        ret = Operation(
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1974 __init__
        self._c_op = _create_c_op(self._graph, node_def, inputs,
    /home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1815 _create_c_op
        raise ValueError(str(e))

    ValueError: slice index 0 of dimension 0 out of bounds. for '{{node strided_slice}} = StridedSlice[Index=DT_INT32, T=DT_INT32, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](Shape, strided_slice/stack, strided_slice/stack_1, strided_slice/stack_2)' with input shapes: [0], [1], [1], [1] and with computed input tensors: input[1] = <0>, input[2] = <1>, input[3] = <1>.