在函数内运行TensorFlow操作

时间:2017-03-21 23:19:22

标签: tensorflow

以下是我程序中的一项功能:

def addfilenames (train_image_dict):
   with tf.name_scope(values=[train_image_dict], name="AddFileNames/"):
        print("Hello")
        filename_queue = tf.RandomShuffleQueue(capacity=len(trainimgs), min_after_dequeu\
e=0,\
                                               dtypes=[tf.string], names=["ImageFile"],\
                                           seed=0, name="filename_queue")

        enq_op = filename_queue.enqueue_many(train_image_dict)
    with tf.variable_scope("main_scope") as scope:
            try:
                epoch = tf.get_variable(name="epoch", shape=[1],\
                                    initializer=tf.zeros_initializer())
            except ValueError:
                scope.reuse_variables()
                epoch = tf.get_variable("epoch")
# I want to increment epoch here 
   return filename_queue, enq_op

我有一个主要功能如下:

if __name__ == "__main__":
    g, drop2 = OverFeatAccurate()
    trainimgs, trainlbls,  classdict = ReadTrain('/local/ujjwal/ILSVRC2015/Data/CLS-LOC/\
train')
    with g.as_default():
        trainimgs_tensor = tf.constant(trainimgs)
        trainimgs_dict = {}
        trainimgs_dict["ImageFile"] = trainimgs_tensor
        filename_q, filename_enqueue_op= addfilenames(trainimgs_dict)

        qr = tf.train.QueueRunner(filename_q, [filename_enqueue_op])
        filename_dequeue_op = filename_q.dequeue()
        init_op = tf.global_variables_initializer()

    sess = tf.Session(graph=g)
    sess.run(init_op)
    coord = tf.train.Coordinator()
    enq_threads = qr.create_threads(sess, coord=coord, start=True)
    counter = 0
    for step in range(100):
        print(sess.run(filename_dequeue_op["ImageFile"]))
        print("Epoch = %d "%(epoch))
counter+=1

    names = [n.name for n in g.as_graph_def().node]

    coord.request_stop()
    coord.join(enq_threads)
    print("Counter = %d"%(counter))

我想在完成epoch功能之前增加addfilenames Tensor。虽然我可以从中返回增量op,但由于它必须在线程上下文中使用,我希望增量发生在addfilenames函数内。我无法将tf.Session()对象传递给函数,因为稍后会调用tf.Session

如果我在tf.Session() as sess内使用addfilenames,我必须再次初始化所有变量。

addfilenames函数中运行增量操作的正确方法是什么?

0 个答案:

没有答案