在tensorflow中列出可用的图表

时间:2017-09-08 23:59:56

标签: tensorflow

我遇到了ValueError: Tensor("conv2d_1/kernel:0", ...) must be from the same graph as Tensor("IteratorGetNext:0", ...)。我正在尝试重用具有Estimator类的keras模型。

我尝试将所有可能的内容包含在

g = tf.Graph() with g.as_default():

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    MODEL = get_keras_model(...)

    def model_fn(mode, features, labels, params):
        logits = MODEL(features)
        ...

    def parser(record):
        ...
    def get_dataset_inp_fn(filenames, epochs=20):
            def dataset_input_fn():
                dataset = tf.contrib.data.TFRecordDataset(filenames)
                dataset = dataset.map(parser)
                ...

with tf.Session(graph=g) as sess:
    est = tf.estimator.Estimator(
            model_fn,
            model_dir=None,
            config=None,
            params={"optimizer": "AdamOptimizer",
                    "opt_params":{}}
            )
    est.train(get_dataset_inp_fn(["mydata.tfrecords"],epochs=20))

但这没有帮助。

有没有办法列出所有定义到当前点的图表?

2 个答案:

答案 0 :(得分:1)

这是一种常规调试技术,将import pdb; pdb.set_trace()放入tf.Graph构造函数中,然后使用bt找出创建图表的人员。我的第一个猜测是Keras不使用默认图并创建自己的图。您可以inspect.getsourcefile(tf.Graph)查找Graph文件位于本地的位置

答案 1 :(得分:0)

检查图形并返回错误的函数(希望它们返回图形地址)调用以下函数来检查图形:

from tensorflow.python.framework.ops import _get_graph_from_inputs
_get_graph_from_inputs([x])

在这种情况下,keras创建的图表与图表g相同,但get_dataset_inp_fn创建的图表与g不同。