我遇到了ValueError: Tensor("conv2d_1/kernel:0", ...) must be from the same graph as Tensor("IteratorGetNext:0", ...)
。我正在尝试重用具有Estimator
类的keras模型。
我尝试将所有可能的内容包含在
中 g = tf.Graph()
with g.as_default():
import tensorflow as tf
g = tf.Graph()
with g.as_default():
MODEL = get_keras_model(...)
def model_fn(mode, features, labels, params):
logits = MODEL(features)
...
def parser(record):
...
def get_dataset_inp_fn(filenames, epochs=20):
def dataset_input_fn():
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parser)
...
with tf.Session(graph=g) as sess:
est = tf.estimator.Estimator(
model_fn,
model_dir=None,
config=None,
params={"optimizer": "AdamOptimizer",
"opt_params":{}}
)
est.train(get_dataset_inp_fn(["mydata.tfrecords"],epochs=20))
但这没有帮助。
有没有办法列出所有定义到当前点的图表?
答案 0 :(得分:1)
这是一种常规调试技术,将import pdb; pdb.set_trace()
放入tf.Graph
构造函数中,然后使用bt
找出创建图表的人员。我的第一个猜测是Keras不使用默认图并创建自己的图。您可以inspect.getsourcefile(tf.Graph)
查找Graph
文件位于本地的位置
答案 1 :(得分:0)
检查图形并返回错误的函数(希望它们返回图形地址)调用以下函数来检查图形:
from tensorflow.python.framework.ops import _get_graph_from_inputs
_get_graph_from_inputs([x])
在这种情况下,keras创建的图表与图表g
相同,但get_dataset_inp_fn
创建的图表与g
不同。