tensorboard无法使用tf.data.Dataset显示图形(解析graph.pbtxt)

时间:2018-02-27 11:30:35

标签: python tensorflow tensorboard

它的全部都在标题中:tensorboard在一个相对适度的模型上解析graph.pbtxt,但仅限于使用tf.data.Dataset输入时,而不是使用fee_dict输入。

这是主要功能,有一个tf.summary.FileWriter(somepath, sess.graph)的调用,但图表无处可寻:

def train_input_fn(batch_size):
    with tf.name_scope("train_dataset"):
        features, labels = dataset.train_x[:50000], dataset.train_y[:50000]
        features = features.reshape((-1, 28, 28, 1)).astype(np.float32)
        features = (features / 255 - 0.1307) / 0.3081
        labels = labels.reshape((-1,)).astype(np.int32)
        pairs = tf.data.Dataset.from_tensor_slices((
            features, labels))
        pairs = pairs.shuffle(len(labels)).batch(batch_size).repeat()
        return pairs.make_one_shot_iterator().get_next()
batch_size = 64

# input data loading
train_inputs, train_labels = train_input_fn(batch_size)


# training graph
with tf.name_scope("inference"):
    train_logits = build_model_layers(train_inputs, 10, True)
    train_loss = tf.losses.sparse_softmax_cross_entropy(
        train_labels, train_logits)
    # train_loss_summ = tf.summary.scalar('loss', train_loss)

    optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(
        train_loss, global_step=tf.train.get_global_step())

saver = tf.train.Saver()

with tf.Session() as sess:
    # variable initialization
    sess.run(tf.global_variables_initializer())

    # summaries setup
    model_dir = sensible_dir(
        "experiments/SpatialTransformerNetwork/checkpoints", "run_")
    train_writer = tf.summary.FileWriter(
        model_dir + "/train", graph=sess.graph)

    sess.run(tf.local_variables_initializer())

    # Run
    for step in range(50000 // batch_size * 1):
        sess.run([train_op])

    saver.save(sess, model_dir + "/model.ckpt")

1 个答案:

答案 0 :(得分:1)

正如你所看到的,我的加载函数使用了tf.data.Dataset.from_tensor_slices,并且埋藏在张量流API文档的浩瀚中的某个地方就是说它会将该数组作为常量加载到图中,后者因此变得太大而无法装载......