tf.train.MonitoredTrainingSession遇到错误:TypeError:无法将feed_dict键解释为Tensor

时间:2019-03-06 06:14:47

标签: python tensorflow

我使用tf.train.MonitoredTrainingSession在分布式张量流中构建会话。 模型文件和主文件位于不同的文件中。在主文件中,默认图形已设置为model.graph,但是报告了代码

TypeError: Cannot interpret feed_dict key as Tensor: Tensor 
Tensor("Placeholder:0", dtype=float32) is not an element of this graph.

代码如下:

model.py

class model:
    def __init__(self):
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.w = tf.Variable()
            self.wt_hldr = tf.placeholder(tf.float32,
                                 shape=[None, input_dim])
            self.lbl_hldr = tf.placeholder(tf.float32)
            self.logit=self.wt_hldr * self.w
            self.loss = tf.nn.sigmoid_cross_entropy_with_logits(
                            logits=logits, labels=self.lbl_hldr)

            update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            optimizer = tf.train.AdamOptimizer(0.01,0.000004)
            optimizer = hvd.DistributedOptimizer(optimizer)
            upde_opt = optimizer.minimizer(self.loss, global_step = global_step)

main.py

import tensorflow as tf
import horovod.tensorflow as hvd
from model import model

if__name__ == '__main__':
    hvd.init()
    model= model()
    graph = model.graph
    hooks= [hvd.BroadcastGlobalVariablesHook(0),
            tf.train.StopAtStepHook(last_step=20000 // hvd.size())]

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.visible_device_list = str(hvd.local_rank())

    checkpoint_dir = 'model/' if hvd.rank() == 0 else None

    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           hooks=hooks,
                                           config=sess_config) as sess:
        global graph
        with graph.as_default():
            while not sess.should_stop():
                batch_wts, batch_labels=train_gen.next()
                loss=sess.run(model.loss, feed_dict=
                              {model.wt_hldr:batch_wts,
                               model.lbl_hldr:batch_labels})
                ....

0 个答案:

没有答案