我正在使用带有占位符的TF Dataset API来存储在初始化迭代器时输入的文件名(不同的文件,取决于它是训练集还是验证集)。我还想使用其他占位符,指示我们是在培训还是在验证(包括在辍学层中)。但是,我无法使用数据集初始化程序将值提供给此占位符(这很有意义,因为这不是数据集的一部分)。然后在使用Dataset API时如何提供其他变量?
关键代码段:
filenames_placeholder = tf.placeholder(tf.string, shape = (None))
is_training = tf.placeholder(tf.bool, shape = ()) # Error: You must feed a value for placeholder tensor 'Placeholder_1' with dtype bool
dataset = tf.data.TFRecordDataset(filenames_placeholder)
# (...) Many other dataset operations
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Model code using "next_element" as inputs including the dropout layer at some point
# where I would like to let the model know if we're training or validating
tf.layers.dropout(x, training = is_training)
# Model execution
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(iterator.initializer, feed_dict = {filenames_placeholder: training_files, is_training: True})
# (...) Performing training
sess.run(iterator.initializer, feed_dict = {filenames_placeholder: training_files, is_training: False})
# (...) Performing validadtion
答案 0 :(得分:1)
在这种情况下,我要做的是使用默认值的附加占位符:
keep_prob = tf.placeholder_with_default(1.0, shape=())
在图中:
tf.layers.dropout(inputs, rate=1-keep_prob)
然后在训练时:
sess.run(...,feed_dict={keep_prob:0.5})
在评估时:
sess.run(...) # No feed_dict here since the keep_prob placeholder has a default value of 1
请注意,在训练过程中喂食占位符会提供额外的float
值,根本不会减慢您的训练速度。