如何将Keras生成器与tf.data API结合使用

时间:2018-10-03 21:45:58

标签: python tensorflow keras tensorflow-datasets

我正在尝试使用Keras预处理库中的生成器。我想尝试一下,因为Keras提供了强大的图像增强功能。但是,我不确定这是否真的可能。

这是我从Keras生成器生成tf数据集的方式:

def make_generator():
    train_datagen = ImageDataGenerator(rescale=1. / 255)
    train_generator = 
    train_datagen.flow_from_directory(train_dataset_folder,target_size=(224, 224), class_mode='categorical', batch_size=32)
    return train_generator

train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32)).shuffle(64).repeat().batch(32)

请注意,如果您尝试直接将train_generator用作tf.data.Dataset.from_generator的参数,则会出现错误。但是,以上方法不会产生错误。

在会话中运行它以检查数据集的输出时,出现以下错误。

iterator = train_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
for i in range(100):
    sess.run(next_element)
  

找到1000个图像属于2类。   -------------------------------------------------- ------------------------- InvalidArgumentError追踪(最近的呼叫   持续)   /usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py   在_do_call(self,fn,* args)1291中尝试:   -> 1292返回fn(* args)1293(errors.OpError为e:

     

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py   在_run_fn中(feed_dict,fetch_list,target_list,选项,run_metadata)   第1276章(1276)   -> 1277选项,feed_dict,fetch_list,target_list,run_metadata)1278

     

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py   在_call_tf_sessionrun(自身,选项,feed_dict,fetch_list,   target_list,run_metadata)1366 self._session,选项,   feed_dict,fetch_list,target_list,   -> 1367 run_metadata)1368

     

InvalidArgumentError:无法在中批处理具有不同形状的张量   元素0。第一个元素的形状为[32,224,224,3],元素29的形状为   形状[8,224,224,3]。 [[{{node IteratorGetNext_2}} =   IteratorGetNextoutput_shapes = [,],   output_types = [DT_FLOAT,DT_FLOAT],   _device =“ / job:localhost /副本:0 /任务:0 /设备:CPU:0”]]

     

在处理上述异常期间,发生了另一个异常:

请让我知道是否有人对此有任何经验或知道其他方法。

更新

使用J.E.K.的建议后,我能够解决问题。

train_dataset = tf.data.Dataset.from_generator(make_generator,(tf.float32, tf.float32))

但是,当我将train_dataset赋予Keras .fit方法时,会出现以下错误。

model_regular.fit(train_dataset,steps_per_epoch=1000,epochs=2)
  

-------------------------------------------------- ---------------------------- ValueError Traceback(最近的呼叫   最后)在()   ----> 1个model_regular.fit(train_dataset,steps_per_epoch = 1000,epochs = 2)

     

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py   适合(自我,x,y,batch_size,时代,冗长,回调,   validate_split,validation_data,随机播放,class_weight,   sample_weight,initial_epoch,steps_per_epoch,validation_steps,   ** kwargs)1507 steps_name ='steps_per_epoch',1508 steps = steps_per_epoch,   -> 1509validation_split = validation_split)1510 1511#准备验证数据。

     

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py   在_standardize_user_data(self,x,y,sample_weight,class_weight,   batch_size,check_steps,steps_name,steps,validation_split)       948 x = self._dataset_iterator_cache [x]       949其他:   -> 950迭代器= x.make_initializable_iterator()       951 self._dataset_iterator_cache [x] =迭代器       952 x =迭代器

     

/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py   在make_initializable_iterator(自己,shared_name)中       119 with ops.colocate_with(iterator_resource):       120初始值设定项= gen_dataset_ops.make_iterator(self._as_variant_tensor(),   -> 121个iterator_resource)       122返回iterator_ops.Iterator(iterator_resource,初始值设定项,       123个self.output_types,self.output_shapes,

     

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_dataset_ops.py   如果_ctx为None或在make_iterator(数据集,迭代器,名称)中2542   不是_ctx._eager_context.is_eager:2543 _,_,_ op =   _op_def_lib._apply_op_helper(   -> 2544“ MakeIterator”,数据集=数据集,迭代器=迭代器,名称=名称)2545 return _op 2546 _result = None

     

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py   在_apply_op_helper中(自己,op_type_name,name,** keywords)       348#需要将所有参数放到列表中。       349#pylint:禁用=受保护的访问   -> 350克= ops._get_graph_from_inputs(_Flatten(keywords.values()))       351#pylint:启用=受保护的访问       352,除了AssertionError为e:

     

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py   在_get_graph_from_inputs(op_input_list,graph)5659图形中   = graph_element.graph 5660 elif original_graph_element不是None:   -> 5661 _assert_same_graph(original_graph_element,graph_element)5662 Elif graph_element.graph不是图:
  5663引发ValueError(“%s不是来自传入的图形。”%   graph_element)

     

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py   在_assert_same_graph(原始项目,项目)5595中,如果   original_item.graph不是item.graph:提高5596   ValueError(“%s必须与%s来自同一张图表。”%(项目,   -> 5597 original_item))5598 5599

     

ValueError:Tensor(“ IteratorV2:0”,shape =(),dtype = resource)必须为   来自与Tensor(“ FlatMapDataset:0”,shape =()相同的图,   dtype = variant)。

这是错误还是Keras拟合方法不打算以这种方式使用?

2 个答案:

答案 0 :(得分:2)

我试图通过一个简单的示例重现您的结果,但发现当在生成器函数和tf.data中使用批处理时,您得到不同的输出形状。

Keras函数train_datagen.flow_from_directory(batch_size=32)已返回形状为[batch_size, width, height, depth]的数据。如果使用tf.data.Dataset().batch(32),则将输出数据再次批处理为形状[batch_size, batch_size, width, height, depth]

这可能由于某些原因导致了您的问题。

答案 1 :(得分:1)

不应该

model_regular.fit(train_dataset,steps_per_epoch=1000,epochs=2)

model_regular.fit(train_dataset.make_one_shot_iterator(),steps_per_epoch=1000,epochs=2)

按照this answer