如何检查训练或测试数据是否已初始化?

时间:2018-12-18 12:35:39

标签: tensorflow tensorflow-datasets

我正在使用带有切换机制的TensorFlow数据集API在训练和测试集之间进行切换。

dataset_iter = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

features, labels = dataset_iter.get_next()

train_init_op = dataset_iter.make_initializer(train_dataset)
test_init_op = dataset_iter.make_initializer(test_dataset)

featureslabels用于图形,例如:

logits = tf.layers.dense(features, units=dataset.labels.shape[-1])
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

对于每个用于测试和训练的时期,通过调用各自的初始化器(train_init_optest_init_op)来切换数据集。

现在,我想使用辍学层,但我不知道如何确定如何检查当前跑步训练或测试集是否已初始化:

is_training = ???
net = tf.layers.dropout(net, rate=0.25, training=is_training)

is_training必须是变量,并且不应在图形建立时间上进行评估。是否应该在每次运行中进行评估。

如何执行此操作?我不想重新定义图表进行测试或培训。

2 个答案:

答案 0 :(得分:0)

好的,我已经想出了一个解决方案:

is_training = tf.Variable(False, dtype=tf.bool)
train_init_op = tf.group(dataset_iter.make_initializer(train_dataset), tf.assign(is_training, True))
test_init_op = tf.group(dataset_iter.make_initializer(test_dataset), tf.assign(is_training, False))

我添加了一个跟踪状态(培训/测试)的附加变量。 调用初始化程序并将其设置为正确的值时,也会调用此变量。

我希望有一个集成的/开箱即用的版本。

如果有人知道这样的解决方案,请随时提供其他答案。

答案 1 :(得分:0)

也许我们可以尝试tf.control_dependencies

with tf.control_dependencies([train_init_op]):
    net = tf.layers.dropout(net, rate=0.25, training=True)