The established way to use the TF Dataset API in Keras is to feed model.fit with make_one_shot_iterator(), but this iterator is only good for one epoch

Asked: 2019-03-31 19:28:12

Tags: tensorflow keras tensorflow-datasets tf.keras

Edit:

To clarify why this question differs from the suggested duplicates: this SO question follows on from those suggested duplicates and asks what exactly Keras does with the techniques described in them. The suggested duplicates specify using the Dataset API's make_one_shot_iterator() in model.fit. My follow-up is that make_one_shot_iterator() can only go through the dataset once, yet the solutions given specify several epochs.


This is a follow-up to these SO questions:

How to Properly Combine TensorFlow's Dataset API and Keras?

Tensorflow keras with tf dataset input

Using tf.data.Dataset as training input to Keras model NOT working

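where "Starting from Tensorflow 1.9, one can pass tf.data.Dataset object directly into keras.Model.fit() and it would act similar to fit_generator". Each of these examples feeds a one-shot iterator over a TF dataset into Keras's model.fit.

An example is below: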

make_one_shot_iterator()

However, according to the Tensorflow Dataset API guide (here: https://www.tensorflow.org/guide/datasets):

  

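A one-shot iterator is the simplest form of iterator, which only supports iterating once through a dataset

So it is only good for 1 epoch. However, the code in the SO questions specifies several epochs, and the code example above specifies 5 epochs.

Is there any explanation for this contradiction? Does Keras somehow know that, once the one-shot iterator has gone through the dataset, it can re-initialize it and reshuffle the data?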

1 Answer:

Answer 0: (score: 1)

You can simply pass the dataset object to model.fit, and Keras will handle the iteration. Consider one of the pre-made datasets:

import tensorflow as tf
train, test = tf.keras.datasets.cifar10.load_data()
dataset = tf.data.Dataset.from_tensor_slices((train[0], train[1]))

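This creates a dataset object from the training data of the cifar10 dataset. In this case no parse function is needed. If you create the dataset from paths to images, or from a list of numpy arrays, you will need one.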

dataset = tf.data.Dataset.from_tensor_slices((image_path, labels_path)) 

In that case you will need a function that loads the actual data from the filename. Numpy arrays can be handled the same way, just without tf.read_file:
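def parse_func(filename):
    # Read the raw bytes; tf.read_file is the TF 1.x name (tf.io.read_file in TF 2.x)
    f = tf.read_file(filename)
    # Decode the bytes into an image tensor
    image = tf.image.decode_image(f)
    # Placeholder: the original leaves label extraction from the filename unspecified
    label = ...
    return image, label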

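Then you can shuffle, batch, and map any parse function onto this dataset. The shuffle buffer size controls how many examples are preloaded before shuffling. Repeat controls the epoch count and is best left as None, so that the dataset repeats indefinitely. You can use either the plain batch function or combine mapping and batching with map_and_batch: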
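# buffer_size, batch_size and num_parallel_batches are placeholders you choose
dataset = dataset.shuffle(buffer_size).repeat()
dataset = dataset.apply(tf.data.experimental.map_and_batch(
    map_func=parse_func, batch_size=batch_size, num_parallel_batches=num_parallel_batches))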
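
To make the role of the shuffle buffer more concrete, here is a minimal plain-Python sketch of buffer-based shuffling. It is an analogy, not TensorFlow code: the function name buffered_shuffle and its seed parameter are made up for illustration, and the real tf.data implementation may differ in its details. The idea is that only buffer_size elements are held at a time, so a small buffer gives a weak shuffle and a buffer of 1 gives none.

```python
import random

def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in an order randomized through a fixed-size buffer.

    Hypothetical helper for illustration; buffer_size must be >= 1.
    """
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)        # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]              # emit a random buffered element
        buffer[idx] = item             # and replace it with the new one
    rng.shuffle(buffer)                # drain whatever is left at the end
    yield from buffer

shuffled = list(buffered_shuffle(range(10), buffer_size=4, seed=0))
in_order = list(buffered_shuffle(range(10), buffer_size=1, seed=0))
```

With buffer_size=4 the first emitted element can only come from the first four inputs, which is why a buffer much smaller than the dataset does not mix distant examples well.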

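Then the dataset object can be passed to model.fit: model.fit(dataset, epochs, steps_per_epoch). Note that steps_per_epoch is a required parameter in this case; it defines when a new epoch starts. So you have to know the epoch size in advance.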
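
The interplay between an endlessly repeating dataset and steps_per_epoch can be sketched in plain Python. This is an analogy under stated assumptions, not Keras internals: itertools.cycle stands in for repeat(), the helpers steps_per_epoch and batches are hypothetical names, and the 10-element list is a toy stand-in for a dataset. It shows that the stream itself has no notion of an epoch; the consumer ends an epoch after a fixed number of batches.

```python
import itertools
import math

def steps_per_epoch(num_examples, batch_size):
    """Number of batches needed to see every example once."""
    return math.ceil(num_examples / batch_size)

def batches(stream, batch_size):
    """Group an endless stream of examples into lists of batch_size."""
    while True:
        yield list(itertools.islice(stream, batch_size))

examples = list(range(10))                # toy stand-in for a dataset
batch_size = 4
steps = steps_per_epoch(len(examples), batch_size)

repeated = itertools.cycle(examples)      # plays the role of .repeat()
batch_stream = batches(repeated, batch_size)

# The consumer, not the stream, decides where an epoch ends.
epochs = [[next(batch_stream) for _ in range(steps)] for _ in range(2)]
```

For the cifar10 training set (50000 examples) and a batch size of 128, the same formula gives 391 steps per epoch. Note that in this sketch, because repetition happens before batching, a batch can straddle the boundary between two passes over the data.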