InvalidArgumentError函数调用堆栈:train_function

时间:2020-07-28 21:04:06

标签: python-3.x tensorflow machine-learning keras

您好,我遇到了这个错误,无法解决任何想法。我正在尝试使用自己的数据集构建模型。因此,我选择了转移学习(VGG16),但仍然无法正常工作。先感谢您。 我正在使用Python 3.8X 最新版本的Tensorflow 2.2X 我试图建立一个可以dd

的分类器
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
import numpy as np
from glob import glob
import matplotlib.pyplot as plt

IMAGE_SIZE = [224, 224]

train_path = 'dataset/Train'
val_path = 'dataset/validation'

vgg = VGG16(input_shape=IMAGE_SIZE + [3], weights='imagenet', include_top=False)


for layer in vgg.layers:
  layer.trainable = False
folders = glob('datasets/Train/*')
  x = Flatten()(vgg.output)
x = Dense(1000, activation='relu')(x)
prediction = Dense(len(folders), activation='softmax')(x)

# create a model object
model = Model(inputs=vgg.input, outputs=prediction)

model.summary()
model.compile(
  loss='categorical_crossentropy',
  optimizer='adam',
  metrics=['accuracy']
)
train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

test_datagen = ImageDataGenerator(rescale = 1./255)

training_set = train_datagen.flow_from_directory('dataset/train',
                                                 target_size = (224, 224),
                                                 batch_size = 32,
                                                 class_mode = 'categorical')

test_set = test_datagen.flow_from_directory('dataset/validation',
                                            target_size = (224, 224),
                                            batch_size = 32,
                                            class_mode = 'categorical')

r = model.fit(
  training_set,
  validation_data=test_set,
  epochs=5,
  steps_per_epoch=len(training_set),
  validation_steps=len(test_set)
)

下面是错误

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-46-a479a62b157d> in <module>
      2                       steps_per_epoch = 1,
      3                       epochs = 10,
----> 4                       validation_data = test_set
      5                      )

~/opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
    106   def _method_wrapper(self, *args, **kwargs):
    107     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108       return method(self, *args, **kwargs)
    109 
    110     # Running inside `run_distribute_coordinator` already.

~/opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1100                 _r=1):
   1101               callbacks.on_train_batch_begin(step)
-> 1102               tmp_logs = self.train_function(iterator)
   1103               if data_handler.should_sync:
   1104                 context.async_wait()

~/opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    794       else:
    795         compiler = "nonXla"
--> 796         result = self._call(*args, **kwds)
    797 
    798       new_tracing_count = self._get_tracing_count()

~/opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    821       # In this case we have created variables on the first call, so we run the
    822       # defunned version which is guaranteed to never create variables.
--> 823       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    824     elif self._stateful_fn is not None:
    825       # Release the lock early so that multiple threads can perform the call

~/opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   2920     with self._lock:
   2921       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2922     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2923 
   2924   @property

~/opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs, cancellation_manager)
   1856                            resource_variable_ops.BaseResourceVariable))],
   1857         captured_inputs=self.captured_inputs,
-> 1858         cancellation_manager=cancellation_manager)
   1859 
   1860   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~/opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1932       # No tape is watching; skip to running the function.
   1933       return self._build_call_outputs(self._inference_function.call(
-> 1934           ctx, args, cancellation_manager=cancellation_manager))
   1935     forward_backward = self._select_forward_and_backward_functions(
   1936         args,

~/opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    555               inputs=args,
    556               attrs=attrs,
--> 557               ctx=ctx)
    558         else:
    559           outputs = execute.execute_with_cancellation(

~/opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  Reduction axis -1 is empty in shape [32,0]
     [[node ArgMax_1 (defined at <ipython-input-45-71c422cbdbf7>:4) ]] [Op:__inference_train_function_3410]

Function call stack:
train_function
[ ]:

0 个答案:

没有答案