Tensorflow出队操作未在set_shape之后执行

时间:2018-07-16 13:12:29

标签: python tensorflow queue shape

我尝试使用队列接口来训练张量流模型。我使用下一个代码片段

import os
from datetime import datetime
import threading

import numpy as np
import tensorflow as tf

from tf_seq2seq_df import Graph
from utils import *
from model_config import Parameters as p


tf.set_random_seed(1)
np.random.seed(1)


def train(model_path, model_name, num_steps, tune=False):

    coord = tf.train.Coordinator()
    b_size = p.batch_size
    placeholders = [tf.placeholder(tf.int32, shape=(b_size, None)),
                    tf.placeholder(tf.float32, shape=(b_size, None, p.num_mels)),
                    tf.placeholder(tf.float32, shape=(b_size, None, p.num_lins)),
                    tf.placeholder(tf.float32, shape=(b_size, None))]

    queue = tf.FIFOQueue(8, [tf.int32, tf.float32, tf.float32, tf.float32])
    enqueue_op = queue.enqueue(placeholders)
    inputs, mel_targets, linear_targets, stop_targets = queue.dequeue()
    inputs.set_shape(placeholders[0].shape)
    mel_targets.set_shape(placeholders[1].shape)
    linear_targets.set_shape(placeholders[2].shape)
    stop_targets.set_shape(placeholders[3].shape)

    graph = Graph(inputs, mel_targets, linear_targets, stop_targets, b_size, 'train')
    graph.add_loss()
    graph.add_optimizer()
    graph.add_summary()

    saver = tf.train.Saver()

    with tf.Session() as sess:

        if tune:
            saver.restore(sess, tf.train.latest_checkpoint(model_path))
        else:
            sess.run(tf.global_variables_initializer())
        fw = tf.summary.FileWriter(p.logdir)

        def enqueue_thread():
            with coord.stop_on_exception():
                while not coord.should_stop():
                    batch_data = Batch(b_size)
                    while True:
                        try:
                            next_batch = next(batch_data)
                            sess.run(enqueue_op, feed_dict={placeholders: next_batch})
                        except StopIteration:
                            break

        threading.Thread(target=enqueue_thread).start()

        sess.run(inputs)


if __name__ == '__main__':
    train('models\\model2', 'model2', 10, tune=False)

我有enqueue_thread()函数来填充队列。而且效果很好。

但是我需要在图形初始化之前设置变量的形状。因此,我在出队操作后添加set_shape。这是一个问题。因为使用set_shape脚本会挂断。队列未装满,出队操作未执行。

我阅读了很多有关此主题的文档和讨论,但是不知道如何解决它。请以什么方式可以更改此脚本给我建议。预先感谢。

0 个答案:

没有答案