我厌倦了使用tf.data api加载mnist数据集来训练一个简单的模型,如下所示,但是却遇到了“维度0的切片索引0超出范围”的异常。我想知道我的代码做错了什么。
import math
import tensorflow as tf
import numpy as np
batch_size = 32
def load_data():
mnist = tf.keras.datasets.mnist
(train_data, train_label), (validation_data, validation_label) = mnist.load_data()
train_data, validation_data = train_data / 255.0, validation_data / 255.0
train_label = train_label.astype(np.float32)
return train_data, train_label
def build_model():
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer=tf.train.AdamOptimizer(),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
train_data, train_label = load_data()
train_sample_count = len(train_data)
train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.repeat()
iter = train_dataset.make_one_shot_iterator()
train_x, train_y = iter.get_next()
model = build_model()
model.fit(
train_dataset,
epochs=10,
steps_per_epoch=math.ceil(train_sample_count/batch_size)
)
完整的堆栈跟踪如下所示
第1/10版
2018-08-23 16:14:49.485165:我tensorflow / core / platform / cpu_feature_guard.cc:141]您的CPU支持该TensorFlow二进制文件未编译为使用的指令:AVX2 FMA
2018-08-23 16:15:25.588057:W tensorflow / core / framework / op_kernel.cc:1275] OP_REQUIRES在strided_slice_op.cc:105处失败:无效参数:维度0的切片索引0超出范围。 / p>
2018-08-23 16:15:26.852912:W tensorflow / core / framework / op_kernel.cc:1275] OP_REQUIRES在iterator_ops.cc:910处失败:已取消:
回溯(最近通话最近一次):
文件
中的文件“ /Users/xievi/workspace/speaker_verification/mnist_with_dataset.py”,第35行steps_per_epoch = math.ceil(train_sample_count / batch_size)
适合的文件“ /Users/xievi/shared_python3_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py”,第1363行
validation_steps = validation_steps)
在fit_loop中,文件“ /Users/xievi/shared_python3_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_arrays.py”,第205行 outs = f(ins)
在调用
中的文件“ /Users/xievi/shared_python3_venv/lib/python3.6/site-packages/tensorflow/python/keras/backend.py”第2914行获取= self._callable_fn(* array_vals)
在“ 调用
”中的文件“ /Users/xievi/shared_python3_venv/lib/python3.6/site-packages/tensorflow/python/client/session.py”中的第1382行run_metadata_ptr)
退出
中的文件“ /Users/xievi/shared_python3_venv/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py”,第519行c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError:维度0的切片索引0超出范围。
[[节点:flatten / strided_slice = StridedSlice [索引= DT_INT32,T = DT_INT32,begin_mask = 0,省略号_mask = 0,end_mask = 0,new_axis_mask = 0,shrink_axis_mask = 1,_device =“ / job:localhost / replica :0 /任务:0 /设备:CPU:0“](展平/形状,指标/ acc /常量,训练/ TFOptimizer /梯度/ dense_1 / Softmax_grad /总和/ reduction_indices,训练/ TFOptimizer / gradients / dense_1 / Softmax_grad / Sum / reduction_indices)]]
答案 0 :(得分:0)
仅作记录,我将分享我的修正。
这是我得到的错误:
ValueError: slice index 0 of dimension 0 out of bounds. for '{{node strided_slice}} = StridedSlice[Index=DT_INT32, T=DT_INT32, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](Shape, strided_slice/stack, strided_slice/stack_1, strided_slice/stack_2)' with input shapes: [0], [1], [1], [1] and with computed input tensors: input[1] = <0>, input[2] = <1>, input[3] = <1>.
此错误阻止了我的Keras模型训练。我尝试将数据点线从Keras自定义生成器更改为tf.data,但错误没有消失。最后,我能够通过将损失函数从以下位置更改来开始训练:
loss = 'sparse_categorical_crossentropy'
收件人:
loss = 'categorical_crossentropy'
我不知道原因,但这对我有用。
这是我的错误的完整记录:
Epoch 1/100
Traceback (most recent call last):
File "classify_cnn_generator.py", line 175, in <module>
history = model.fit(train_generator,
File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
tmp_logs = train_function(iterator)
File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
result = self._call(*args, **kwds)
File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 696, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3065, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:806 train_function *
return step_function(self, iterator)
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:796 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:1211 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2585 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2945 _call_for_each_replica
return fn(*args, **kwargs)
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:789 run_step **
outputs = model.train_step(data)
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:748 train_step
loss = self.compiled_loss(
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/keras/engine/compile_utils.py:212 __call__
batch_dim = array_ops.shape(y_t)[0]
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
return target(*args, **kwargs)
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py:1013 _slice_helper
return strided_slice(
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
return target(*args, **kwargs)
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/ops/array_ops.py:1186 strided_slice
op = gen_array_ops.strided_slice(
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/ops/gen_array_ops.py:10347 strided_slice
_, _, _op, _outputs = _op_def_library._apply_op_helper(
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py:742 _apply_op_helper
op = g._create_op_internal(op_type_name, inputs, dtypes=None,
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py:591 _create_op_internal
return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3477 _create_op_internal
ret = Operation(
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1974 __init__
self._c_op = _create_c_op(self._graph, node_def, inputs,
/home/deogun/alali/.conda/envs/mytf2.3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:1815 _create_c_op
raise ValueError(str(e))
ValueError: slice index 0 of dimension 0 out of bounds. for '{{node strided_slice}} = StridedSlice[Index=DT_INT32, T=DT_INT32, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](Shape, strided_slice/stack, strided_slice/stack_1, strided_slice/stack_2)' with input shapes: [0], [1], [1], [1] and with computed input tensors: input[1] = <0>, input[2] = <1>, input[3] = <1>.