Converting a Keras generator to a TensorFlow Dataset to train ResNet50

Date: 2020-06-15 13:05:42

Tags: python-3.x keras neural-network tensorflow-datasets

I am converting Python code from the keras namespace to tf.keras. It trains ResNet50. The new Model.fit() method cannot find a data adapter for my simple generator, and validation_data no longer accepts generators at all. So I am trying to convert the generator to a Dataset with the tensorflow.data.Dataset.from_generator method.

The images are grayscale and stored as raw bytes, one byte per pixel. The generator contains lines like these:

        def __next__( self ):
            return self.next()

        def __call__( self ):
            return self.next()

        def next( self ):
            #reading files
            ...

            resultLabels = numpy.zeros( ( count, len( classes ) ), "float32" )
            resultImages = numpy.zeros( ( count, patchSize, patchSize, 3 ), "float32" )

            #filling result with images and labels
                ...
                fileBytes = numpy.reshape( numpy.fromfile( self.ImageLabelsAndPaths[i][1], "uint8" ), (patchSize, patchSize), "F" ).astype( "float32" )

                imageWithChannels = numpy.zeros( ( patchSize, patchSize, 3 ), "float32" )
                # Because Resnet50 requires RGB images and we have grayscale
                imageWithChannels[:,:,0] = fileBytes
                imageWithChannels[:,:,1] = fileBytes
                imageWithChannels[:,:,2] = fileBytes

                resultImages[i - cursor] = imageWithChannels

            return ( resultImages, resultLabels )

So resultImages is an array of length batch_size = 16 holding the image pixel arrays. Its numpy shape is (16, 256, 256, 3), and the resultLabels shape is (16, 3); there are 3 classes at the moment.
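A quick sanity check of those shapes could look like this (a sketch only, assuming the FileIterator class shown above and numpy imported as in the snippets):

        import numpy

        # Pull one batch from the generator and confirm the shapes described above.
        batchImages, batchLabels = next( FileIterator( "train" ) )
        print( numpy.shape( batchImages ) )   # expected: (16, 256, 256, 3)
        print( numpy.shape( batchLabels ) )   # expected: (16, 3)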

Next, I convert it to a Dataset:

            trainGenerator = FileIterator( "train" )
            trainDataset = tf.data.Dataset.from_generator( trainGenerator, (tf.float32, tf.float32), (tf.TensorShape([batchSize, patchSize, patchSize, 3]), tf.TensorShape([batchSize, len(classes)]) ) )
            validationGenerator = FileIterator( "validate" )
            validationDataset = tf.data.Dataset.from_generator( validationGenerator, (tf.float32, tf.float32), (tf.TensorShape([batchSize, patchSize, patchSize, 3]), tf.TensorShape([batchSize, len(classes)]) ) )

But I get this error:

TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.float32), but the yielded element was [[[[185. 185. 185.]
   [158. 158. 158.]
   [145. 145. 145.]
   ...

The code samples for Dataset.from_generator have the second element of the tuple as an array as well, with something like output_types=(tf.int64, tf.int64). I suppose it works there.
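For reference, the documentation-style sample being referred to looks roughly like this (a sketch from memory; the exact example in the docs may differ):

        import tensorflow as tf

        def gen():
            # Yields one element per iteration: a scalar and a variable-length list.
            for i in range( 3 ):
                yield ( i, [1] * i )

        ds = tf.data.Dataset.from_generator(
            gen,
            output_types = (tf.int64, tf.int64),
            output_shapes = (tf.TensorShape([]), tf.TensorShape([None])) )

Note that each yielded element there is a single (value, value) tuple, not a whole batch.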

Trying to pass arrays for the output types results in another error:

TypeError: unhashable type: 'list'

What should I change to make this work?

1 Answer:

Answer 0 (score: 1):

Well, after spending another two days trying to fix some really misleading errors and python.exe crashes when it finally ran, I was able to convert the generator into a TensorFlow Dataset.

I could not make it work with batching inside the generator, and numpy.array is not accepted by the Dataset because, in the Dataset's world, it is not a sequence; it also turned out to be important to return a tuple. I am not sure about the samples that use "yield" or "return data, labels".

The generator:

        def __iter__(self):
            return self

        def __call__( self ):
            return self

        def __len__(self):
            return self.TotalCount

        def __next__( self ):
            ...
            resultLabel = numpy.zeros( len( classes ), "float32" )
            resultImage = numpy.zeros( ( patchSize, patchSize, 3 ), "float32" )
            # fill those two
            ...

            return (resultImage.tolist(), resultLabel.tolist())
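
Regarding the "yield" samples mentioned above, from_generator also accepts a plain generator function. A minimal sketch, assuming the same numpy import and the patchSize and classes globals used throughout this post:

        def sampleGenerator():
            # Sketch only: loop over the (already shuffled) file list indefinitely,
            # yielding one (image, label) tuple per iteration.
            while True:
                resultLabel = numpy.zeros( len( classes ), "float32" )
                resultImage = numpy.zeros( ( patchSize, patchSize, 3 ), "float32" )
                # ... fill resultImage and resultLabel from the next file ...
                yield ( resultImage.tolist(), resultLabel.tolist() )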

And the Dataset + model.fit:

            trainGenerator = FileIterator( "train" )
            validationGenerator = FileIterator( "validate" )

            trainDataset = tf.data.Dataset.from_generator( trainGenerator, output_types=(tf.float32, tf.float32), output_shapes=(tf.TensorShape([patchSize, patchSize, 3]), tf.TensorShape([len(classes)]) ) )
            trainDataset = trainDataset.batch( batchSize )
            validationDataset = tf.data.Dataset.from_generator( validationGenerator, output_types=(tf.float32, tf.float32), output_shapes=(tf.TensorShape([patchSize, patchSize, 3]), tf.TensorShape([len(classes)]) ) )
            validationDataset = validationDataset.batch( batchSize )


            trainResult = model.fit( x = trainDataset,
                                     epochs = epochsForDenseLayer,
                                     steps_per_epoch = trainGenerator.StepsPerEpoch,
                                     verbose = 2,
                                     validation_data = validationDataset,
                                     validation_steps = validationGenerator.StepsPerEpoch,
                                     validation_freq = 1,
                                     shuffle = False, # already shuffled by generator
                                     workers = cpuCoresCount,
                                     use_multiprocessing = False
                                    )
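
Optionally, prefetching can be chained right after the .batch() calls above so that the generator's file reads overlap with training; a sketch, assuming tf.data.experimental.AUTOTUNE is available in the installed TensorFlow version:

            # Optional: let tf.data prepare the next batches while the model trains.
            trainDataset = trainDataset.prefetch( tf.data.experimental.AUTOTUNE )
            validationDataset = validationDataset.prefetch( tf.data.experimental.AUTOTUNE )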