Changing the tf.dataset source in TensorFlow at runtime

Date: 2019-02-27 09:40:46

Tags: python tensorflow tensorflow-datasets

I have two tf.datasets, one for validation and one for training.

From time to time I want to switch the data source so that I can run validation and check its accuracy.

This blog suggests using a placeholder and feeding ordinary numpy arrays to it, but that defeats the whole point of the efficiency gains. As the tf.data API guide puts it:

Warning: "Feeding" is the least efficient way to feed data into a TensorFlow program and should only be used for small experiments and debugging.
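For context, the approach being ruled out looks roughly like this: every training step copies a numpy batch into the graph through feed_dict. This is only a minimal sketch; the arrays and the tiny stand-in model are hypothetical and are there just to make the snippet runnable:

import numpy as np
import tensorflow as tf

# Hypothetical in-memory data, for illustration only
train_x = np.random.rand(1000, 32).astype(np.float32)
train_y = np.random.randint(0, 2, size=1000).astype(np.int64)

x_ph = tf.placeholder(tf.float32, shape=[None, 32])
y_ph = tf.placeholder(tf.int64, shape=[None])

# A trivial stand-in model so the sketch runs end to end
logits = tf.layers.dense(x_ph, 2)
loss = tf.losses.sparse_softmax_cross_entropy(labels=y_ph, logits=logits)
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for start in range(0, len(train_x), 64):
        # Every step ships a numpy batch through feed_dict --
        # this is the "feeding" the warning above refers to.
        sess.run(train_op, feed_dict={x_ph: train_x[start:start + 64],
                                      y_ph: train_y[start:start + 64]})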

So, here is a conceptual example of what I want to achieve:

# Load the datasets from tfrecord files:
val_dataset = tf.data.TFRecordDataset([val_recordfile])
train_dataset = tf.data.TFRecordDataset([train_recordfile])

##  Batch size and shuffling etc. here  ##

iterator_tr = train_dataset.make_initializable_iterator()
iterator_val = val_dataset.make_initializable_iterator()

###############################################
##  This is the magic:                       ##
it_op = tf.iterator_placeholder()
##  tf.iterator_placeholder does not exist!  ##
##  and demonstrates my needs                ##
###############################################
X, Y = it_op.get_next()
predictions = model(X)
train_op = train_and_minimize(X, Y)
acc_op = get_accuracy(Y, predictions)

with tf.Session() as sess:
    # Initialize the iterators here
    accuracy_tr, _ = sess.run([acc_op, train_op], feed_dict={it_op: iterator_tr})
    accuracy_val = sess.run(acc_op, feed_dict={it_op: iterator_val})

Of course, it does not have to be done in this exact way!

I would prefer an idiomatic, tensor-based TensorFlow way, but any approach that does not require feeding the raw data would be great for me!

1 Answer:

Answer 0: (score: 0)

It turns out that what I had in mind was not far from a workable solution. It is in fact presented in the TensorFlow guide on datasets:

# Load the datasets in some form of tf.data.Dataset
tr_dataset = get_dataset(TRAINING)
val_dataset = get_dataset(VALIDATION)
# batching etc..
train_iterator = tr_dataset.make_initializable_iterator()
val_iterator = val_dataset.make_initializable_iterator()
# Make a feedable iterator that is selected via a string handle
iterator_handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(iterator_handle,
                                               train_iterator.output_types,
                                               output_shapes=train_iterator.output_shapes)
with tf.Session() as sess:
    # Create string handles for the iterators
    train_iterator_handle = sess.run(train_iterator.string_handle())
    val_iterator_handle = sess.run(val_iterator.string_handle())
    # Now initialize the iterators (the handle placeholder is not needed here)
    sess.run(train_iterator.initializer)
    sess.run(val_iterator.initializer)
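To actually switch between the two sources, the handle is fed whenever ops are run. Below is a minimal sketch of that step, reusing the question's pseudocode helpers model(), train_and_minimize() and get_accuracy() (those names come from the question, not from a real API):

# Graph construction (this goes before the session block above);
# model(), train_and_minimize() and get_accuracy() are the
# question's pseudocode helpers, assumed to exist.
X, Y = iterator.get_next()
predictions = model(X)
train_op = train_and_minimize(X, Y)
acc_op = get_accuracy(Y, predictions)

# Inside the session, after the iterators are initialized:
# feeding the handle (a short string) picks the data source,
# so no raw example data ever passes through feed_dict.
accuracy_tr, _ = sess.run([acc_op, train_op],
                          feed_dict={iterator_handle: train_iterator_handle})
accuracy_val = sess.run(acc_op,
                        feed_dict={iterator_handle: val_iterator_handle})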