我使用tf.train.MonitoredTrainingSession在分布式张量流中构建会话。 模型文件和主文件位于不同的文件中。在主文件中,默认图形已设置为model.graph,但是报告了代码
TypeError: Cannot interpret feed_dict key as Tensor: Tensor
Tensor("Placeholder:0", dtype=float32) is not an element of this graph.
代码如下:
model.py
class model:
def __init__(self):
self.graph = tf.Graph()
with self.graph.as_default():
self.w = tf.Variable()
self.wt_hldr = tf.placeholder(tf.float32,
shape=[None, input_dim])
self.lbl_hldr = tf.placeholder(tf.float32)
self.logit=self.wt_hldr * self.w
self.loss = tf.nn.sigmoid_cross_entropy_with_logits(
logits=logits, labels=self.lbl_hldr)
update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS)
optimizer = tf.train.AdamOptimizer(0.01,0.000004)
optimizer = hvd.DistributedOptimizer(optimizer)
upde_opt = optimizer.minimizer(self.loss, global_step = global_step)
main.py
import tensorflow as tf
import horovod.tensorflow as hvd
from model import model
if__name__ == '__main__':
hvd.init()
model= model()
graph = model.graph
hooks= [hvd.BroadcastGlobalVariablesHook(0),
tf.train.StopAtStepHook(last_step=20000 // hvd.size())]
sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
checkpoint_dir = 'model/' if hvd.rank() == 0 else None
with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
hooks=hooks,
config=sess_config) as sess:
global graph
with graph.as_default():
while not sess.should_stop():
batch_wts, batch_labels=train_gen.next()
loss=sess.run(model.loss, feed_dict=
{model.wt_hldr:batch_wts,
model.lbl_hldr:batch_labels})
....