I understand that there are advantages (especially as I scale up the size of the models I build and of the datasets they use) to using TensorFlow's new Dataset idiom as the feeding pipeline for my data. However, I'm having trouble mapping my existing feed_dict-based code onto this new model.
One problem I face is that I can't untangle how batching and epochs interact, or how either of them interleaves with the logging and validation I do regularly.
For example, how would the following map onto using Dataset?
# Load and process data into tensors of dimension (N, C_i) for input and (N, C_o) for output,
# where N is the number of examples and C_ is the number of channels, and the values are activations
train_x, train_y, valid_x, valid_y = load_data(file, [segments], ...)
train_size = len(train_x)

train_stats_feed = {input_activation: train_x, correct_output: train_y, is_train: False}
valid_stats_feed = {input_activation: valid_x, correct_output: valid_y, is_train: False}

with tf.Session(config=tf.ConfigProto(...)) as sess:
    sess.run(tf.global_variables_initializer())

    # Some analysis; not always done but the code needs to support it
    train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), 0)
    test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), 0)
    test_writer.add_summary(sess.run(gs_summary), 0)
    print(log_fmt.format(0, float(sess.run(accuracy, feed_dict=valid_stats_feed)),
                         float(sess.run(loss, feed_dict=valid_stats_feed))))

    for ep in range(epochs):
        # Slice the training data into random mini-batches
        batch_indices = np.array_split(np.random.permutation(train_size), int(train_size/mb_size))
        for mini_batch_indices in batch_indices:
            sess.run(train_step, feed_dict={input_activation: train_x[mini_batch_indices],
                                            correct_output: train_y[mini_batch_indices], is_train: True})
            gs = int(sess.run(global_step))
            if gs % log_steps == 0:
                test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), gs)
                train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), gs)
                acc = float(sess.run(accuracy, feed_dict=valid_stats_feed))
                sess.run(validation_accuracy.assign(acc))
                print(log_fmt.format(gs, acc, float(sess.run(loss, feed_dict=valid_stats_feed))))
        print(ep_fmt.format(ep + 2))
        test_writer.add_summary(sess.run(gs_summary), ep + 1)
In case they're needed, some of the less obvious definitions from the above:
# Preliminaries
# Some basic preliminaries, the details of which are not important to the question
# Mostly pretty standard; obvious things omitted from MWE for brevity
global_step = tf.Variable(0, trainable=False, name='global_step')
validation_accuracy = tf.Variable(0.0, trainable=False, name='validation_accuracy', dtype=tf.float32)
is_train = tf.placeholder(tf.bool, [], name='is_train')
input_activation = tf.placeholder(tf.float32, shape=[None, in_nodes], name='inputs')
correct_output = tf.placeholder(tf.float32, shape=[None, out_nodes], name='correct_outputs')
network_output = tf.identity(out_activations)
correct_predictions = correct_fn(correct_output, network_output)
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
error = cost_fn(correct_output, network_output)
loss = error + FLAGS.regularization_weight * sum(tf.nn.l2_loss(w) for w in layer_weights)
train_step = tf.train.MomentumOptimizer(learning_rate, momentum=momentum).minimize(loss, global_step=global_step)
# Logging
train_writer = tf.summary.FileWriter(trainlogfile, tf.get_default_graph())
test_writer = tf.summary.FileWriter(testlogfile, tf.get_default_graph())
gs_summary = tf.summary.scalar('global_step_at_epoch', global_step)
merged = tf.summary.merge_all()
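For concreteness, my rough guess at how the epoch/batch part would translate is below. This is just a sketch using tf.data (the names train_x, train_y, train_size, mb_size, epochs, and train_step are from my code above), and it may well be wrong; what I can't see is where the periodic full-set summaries and validation logging fit once the model reads from the iterator instead of from placeholders:
# A rough guess, not working code: replace the manual permutation/split
# with shuffle() and batch(), and reinitialize the iterator each epoch.
dataset = (tf.data.Dataset.from_tensor_slices((train_x, train_y))
           .shuffle(buffer_size=train_size)  # stands in for np.random.permutation
           .batch(mb_size))                  # stands in for np.array_split
iterator = dataset.make_initializable_iterator()
next_x, next_y = iterator.get_next()         # presumably the model is built on these?

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for ep in range(epochs):
        sess.run(iterator.initializer)       # fresh shuffle each epoch
        while True:
            try:
                sess.run(train_step)         # no feed_dict any more
            except tf.errors.OutOfRangeError:
                break
        # ...but where do the periodic full-set train/validation summaries go?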
Answer 0 (score: -1)
Here is a minimal training loop. The same logic applies for validation; see the sketch after the code below.
# Define placeholders for the input data and labels
inputs_placeholder = tf.placeholder(train_x.dtype, train_x.shape)
labels_placeholder = tf.placeholder(train_y.dtype, train_y.shape)

# Define a Dataset object using the above placeholders
# (tf.data.Dataset in TF >= 1.4; tf.contrib.data.Dataset in older releases)
dataset = tf.data.Dataset.from_tensor_slices((inputs_placeholder, labels_placeholder))

# Define batch_size
batch_size = 128
dataset = dataset.batch(batch_size)

# Define an initializable iterator so the dataset can be re-fed each epoch
iterator = dataset.make_initializable_iterator()

# Get one batch
next_example, next_label = iterator.get_next()

# Calculate loss from the model function you are using
loss = some_model(next_example, next_label)

# Set number of epochs here
num_epochs = 100

with tf.Session() as sess:
    for _ in range(num_epochs):
        # (Re)initialize the iterator, feeding the full arrays once per epoch
        sess.run(iterator.initializer, feed_dict={inputs_placeholder: train_x,
                                                  labels_placeholder: train_y})
        # Drain the dataset one batch at a time until it is exhausted
        while True:
            try:
                _loss = sess.run(loss)
            except tf.errors.OutOfRangeError:
                break
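To actually interleave validation with training, one option (a sketch, not part of the original answer) is a reinitializable iterator built with tf.data.Iterator.from_structure: the same get_next() tensors can then be switched between a training and a validation dataset. Here train_x/train_y/valid_x/valid_y, learning_rate, and momentum are the question's names, and some_model is the placeholder model function from the snippet above:
# Sketch only: one iterator, switchable between training and validation data.
import tensorflow as tf

train_ds = (tf.data.Dataset.from_tensor_slices((train_x, train_y))
            .shuffle(buffer_size=len(train_x))
            .batch(128))
valid_ds = tf.data.Dataset.from_tensor_slices((valid_x, valid_y)).batch(128)

# A reinitializable iterator is defined by the datasets' common structure
# rather than by one dataset, so it can be pointed at either of them.
iterator = tf.data.Iterator.from_structure(train_ds.output_types,
                                           train_ds.output_shapes)
next_example, next_label = iterator.get_next()
train_init = iterator.make_initializer(train_ds)
valid_init = iterator.make_initializer(valid_ds)

loss = some_model(next_example, next_label)
train_step = tf.train.MomentumOptimizer(learning_rate, momentum).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for ep in range(num_epochs):
        sess.run(train_init)              # one shuffled pass over the training set
        while True:
            try:
                sess.run(train_step)
            except tf.errors.OutOfRangeError:
                break
        sess.run(valid_init)              # then one pass over the validation set
        losses = []
        while True:
            try:
                losses.append(sess.run(loss))
            except tf.errors.OutOfRangeError:
                break
        print('epoch {}: mean validation loss {}'.format(ep, sum(losses) / len(losses)))
Periodic summary writes (the merged/FileWriter logic from the question) would slot into the same loops, keyed off global_step as before.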