何时在Tensorflow Estimator中使用迭代器

时间:2018-10-31 03:30:55

标签: tensorflow tensorflow-datasets tensorflow-estimator

在Tensorflow指南中,该指南在两个单独的地方描述了虹膜数据示例的输入功能。一个输入函数仅返回数据集本身,而另一个输入函数返回带有迭代器的数据集。

摘自预制的估算器指南:https://www.tensorflow.org/guide/premade_estimators

def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

# Shuffle, repeat, and batch the examples.
return dataset.shuffle(1000).repeat().batch(batch_size)

来自自定义估算器指南:https://www.tensorflow.org/guide/custom_estimators

def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)

# Return the read end of the pipeline.
return dataset.make_one_shot_iterator().get_next()

我很困惑哪一个是正确的,如果它们都用于不同的情况,那么何时使用迭代器返回数据集是正确的?

1 个答案:

答案 0 :(得分:3)

如果输入函数返回var trailLength = 8 // The length of trail (8 by default; put more for longer "tail") var path = "http://www.javascriptkit.com/script/script2/cursor.gif" // URL of your image var standardbody=(document.compatMode=="CSS1Compat")? document.documentElement : document.body //create reference to common "body" across doctypes var i,d = 0 function initTrail() { // prepares the script images = new Array() // prepare the image array for (i = 0; i < parseInt(trailLength); i++) { images[i] = new Image() images[i].src = path } storage = new Array() // prepare the storage for the coordinates for (i = 0; i < images.length*3; i++) { storage[i] = 0 } for (i = 0; i < images.length; i++) { // make divs for IE and layers for Navigator document.write('<div id="obj' + i + '" style="position: absolute; z-Index: 100; height: 0; width: 0"><img src="' + images[i].src + '"></div>') } trail() } function trail() { // trailing function for (i = 0; i < images.length; i++) { // for every div/layer document.getElementById("obj" + i).style.top = storage[d]+'px' // the Y-coordinate document.getElementById("obj" + i).style.left = + storage[d+1]+'px' // the X-coordinate d = d+2 } for (i = storage.length; i >= 2; i--) { // save the coordinate for the div/layer that's behind storage[i] = storage[i-2] } d = 0 // reset for future use var timer = setTimeout("trail()",10) // call recursively } function processEvent(e) { // catches and processes the mousemove event if (window.event) { // for IE storage[0] = window.event.y+standardbody.scrollTop+10 storage[1] = window.event.x+standardbody.scrollLeft+10 } else { storage[0] = e.pageY+12 storage[1] = e.pageX+12 } } initTrail() document.onmousemove = processEvent // start capturing ,则会在幕后创建一个迭代器,并使用其tf.data.Dataset函数为模型提供输入。这在源代码中有些隐藏,请参见get_next() here

我相信这只是在最近的更新中实现的,因此较早的教程仍在其输入函数中显式返回parse_input_fn_result,因为这是当时的唯一选择。两者之间应该没有什么区别,但是您可以通过返回数据集而不是迭代器来节省一点代码。