它的全部都在标题中:tensorboard在一个相对适度的模型上解析graph.pbtxt,但仅限于使用tf.data.Dataset输入时,而不是使用fee_dict输入。
这是主要功能,有一个tf.summary.FileWriter(somepath, sess.graph)
的调用,但图表无处可寻:
def train_input_fn(batch_size):
with tf.name_scope("train_dataset"):
features, labels = dataset.train_x[:50000], dataset.train_y[:50000]
features = features.reshape((-1, 28, 28, 1)).astype(np.float32)
features = (features / 255 - 0.1307) / 0.3081
labels = labels.reshape((-1,)).astype(np.int32)
pairs = tf.data.Dataset.from_tensor_slices((
features, labels))
pairs = pairs.shuffle(len(labels)).batch(batch_size).repeat()
return pairs.make_one_shot_iterator().get_next()
batch_size = 64
# input data loading
train_inputs, train_labels = train_input_fn(batch_size)
# training graph
with tf.name_scope("inference"):
train_logits = build_model_layers(train_inputs, 10, True)
train_loss = tf.losses.sparse_softmax_cross_entropy(
train_labels, train_logits)
# train_loss_summ = tf.summary.scalar('loss', train_loss)
optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(
train_loss, global_step=tf.train.get_global_step())
saver = tf.train.Saver()
with tf.Session() as sess:
# variable initialization
sess.run(tf.global_variables_initializer())
# summaries setup
model_dir = sensible_dir(
"experiments/SpatialTransformerNetwork/checkpoints", "run_")
train_writer = tf.summary.FileWriter(
model_dir + "/train", graph=sess.graph)
sess.run(tf.local_variables_initializer())
# Run
for step in range(50000 // batch_size * 1):
sess.run([train_op])
saver.save(sess, model_dir + "/model.ckpt")
答案 0 :(得分:1)
正如你所看到的,我的加载函数使用了tf.data.Dataset.from_tensor_slices,并且埋藏在张量流API文档的浩瀚中的某个地方就是说它会将该数组作为常量加载到图中,后者因此变得太大而无法装载......