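I have two tf.data datasets, one for validation and one for training.
From time to time I want to switch the data source so that I can run validation and check the accuracy.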
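This blog suggests using a placeholder and feeding ordinary NumPy arrays into it. But that defeats the whole purpose of an efficient input pipeline. As the tf.data API guide states:
Warning: "Feeding" is the least efficient way to input data into a TensorFlow program and should only be used for small experiments and debugging.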
So here is a conceptual example of what I want to achieve:
```python
import tensorflow as tf

# Load the datasets from TFRecord files:
val_dataset = tf.data.TFRecordDataset([val_recordfile])
train_dataset = tf.data.TFRecordDataset([train_recordfile])
## Batching, shuffling, etc. go here ##
iterator_tr = train_dataset.make_initializable_iterator()
iterator_val = val_dataset.make_initializable_iterator()
###############################################
## This is the magic: ##
it_op = tf.iterator_placeholder()
## tf.iterator_placeholder does not exist! ##
## It only illustrates what I need.         ##
###############################################
X, Y = it_op.get_next()
predictions = model(X)
train_op = train_and_minimize(X, Y)
acc_op = get_accuracy(Y, predictions)
with tf.Session() as sess:
    # Initialize the iterators here
    accuracy_tr, _ = sess.run([acc_op, train_op], feed_dict={it_op: iterator_tr})
    accuracy_val = sess.run(acc_op, feed_dict={it_op: iterator_val})
```
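Of course, it does not have to be done in exactly this way!
I would prefer an idiomatic TensorFlow approach, but any approach that does not require feeding raw data would be great!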
Answer 0 (score: 0)
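It turns out that what I proposed was not far from something that actually works. It is in fact described, as a "feedable" iterator, in the Tensorflow's guide on datasets: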
```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API

# Load the datasets in some form of tf.data.Dataset
tr_dataset = get_dataset(TRAINING)
val_dataset = get_dataset(VALIDATION)
# batching etc.
train_iterator = tr_dataset.make_initializable_iterator()
val_iterator = val_dataset.make_initializable_iterator()
# Make an iterator that is selected at run time by a string handle
iterator_handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    iterator_handle,
    train_iterator.output_types,
    output_shapes=train_iterator.output_shapes)
next_element = iterator.get_next()
with tf.Session() as sess:
    # Create string handles for the iterators
    train_iterator_handle = sess.run(train_iterator.string_handle())
    val_iterator_handle = sess.run(val_iterator.string_handle())
    # Now initialize the iterators (the initializers do not read the handle)
    sess.run(train_iterator.initializer)
    sess.run(val_iterator.initializer)
    # Pick the data source per run by feeding the matching handle, e.g.
    # sess.run(next_element, feed_dict={iterator_handle: val_iterator_handle})
```
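The snippet above sets the iterators up but does not run the switch itself. A minimal end-to-end sketch of the same feedable-iterator pattern might look like the following. It assumes TensorFlow 1.14+ or 2.x driven through the `tf.compat.v1` graph-mode API, and it replaces the TFRecord files with tiny hypothetical in-memory datasets so the batches are easy to inspect; `run_feedable_demo` is likewise a hypothetical name.

```python
import numpy as np
import tensorflow as tf

# Feedable iterators are a graph-mode (TF 1.x style) feature; under TF 2.x
# eager execution must be switched off and the tf.compat.v1 API used.
tf.compat.v1.disable_eager_execution()


def run_feedable_demo():
    # Hypothetical toy datasets standing in for the TFRecord-based ones.
    train_ds = tf.data.Dataset.from_tensor_slices(
        np.arange(0, 6, dtype=np.int64)).batch(2)
    val_ds = tf.data.Dataset.from_tensor_slices(
        np.arange(100, 104, dtype=np.int64)).batch(2)

    train_it = tf.compat.v1.data.make_initializable_iterator(train_ds)
    val_it = tf.compat.v1.data.make_initializable_iterator(val_ds)

    # A scalar string placeholder decides which iterator feeds the graph.
    handle = tf.compat.v1.placeholder(tf.string, shape=[])
    it = tf.compat.v1.data.Iterator.from_string_handle(
        handle,
        tf.compat.v1.data.get_output_types(train_ds),
        tf.compat.v1.data.get_output_shapes(train_ds))
    next_batch = it.get_next()  # the model would be built on this tensor

    with tf.compat.v1.Session() as sess:
        # String handles are computed once and reused in every feed_dict.
        train_h = sess.run(train_it.string_handle())
        val_h = sess.run(val_it.string_handle())
        sess.run([train_it.initializer, val_it.initializer])

        # Switch sources simply by feeding a different handle.
        first_train = sess.run(next_batch, feed_dict={handle: train_h})
        val_batch = sess.run(next_batch, feed_dict={handle: val_h})
        second_train = sess.run(next_batch, feed_dict={handle: train_h})

    return first_train.tolist(), val_batch.tolist(), second_train.tolist()
```

Because each handle refers to a live iterator, the training iterator continues with `[2, 3]` after the validation batch instead of starting over. A reinitializable iterator (`Iterator.from_structure` plus `make_initializer`) can also switch sources, but re-running its initializer restarts the dataset, which is why the feedable version suits interleaving training and validation steps.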