我尝试使用队列接口来训练张量流模型。我使用下一个代码片段
import os
from datetime import datetime
import threading
import numpy as np
import tensorflow as tf
from tf_seq2seq_df import Graph
from utils import *
from model_config import Parameters as p
tf.set_random_seed(1)
np.random.seed(1)
def train(model_path, model_name, num_steps, tune=False):
coord = tf.train.Coordinator()
b_size = p.batch_size
placeholders = [tf.placeholder(tf.int32, shape=(b_size, None)),
tf.placeholder(tf.float32, shape=(b_size, None, p.num_mels)),
tf.placeholder(tf.float32, shape=(b_size, None, p.num_lins)),
tf.placeholder(tf.float32, shape=(b_size, None))]
queue = tf.FIFOQueue(8, [tf.int32, tf.float32, tf.float32, tf.float32])
enqueue_op = queue.enqueue(placeholders)
inputs, mel_targets, linear_targets, stop_targets = queue.dequeue()
inputs.set_shape(placeholders[0].shape)
mel_targets.set_shape(placeholders[1].shape)
linear_targets.set_shape(placeholders[2].shape)
stop_targets.set_shape(placeholders[3].shape)
graph = Graph(inputs, mel_targets, linear_targets, stop_targets, b_size, 'train')
graph.add_loss()
graph.add_optimizer()
graph.add_summary()
saver = tf.train.Saver()
with tf.Session() as sess:
if tune:
saver.restore(sess, tf.train.latest_checkpoint(model_path))
else:
sess.run(tf.global_variables_initializer())
fw = tf.summary.FileWriter(p.logdir)
def enqueue_thread():
with coord.stop_on_exception():
while not coord.should_stop():
batch_data = Batch(b_size)
while True:
try:
next_batch = next(batch_data)
sess.run(enqueue_op, feed_dict={placeholders: next_batch})
except StopIteration:
break
threading.Thread(target=enqueue_thread).start()
sess.run(inputs)
if __name__ == '__main__':
train('models\\model2', 'model2', 10, tune=False)
我有enqueue_thread()函数来填充队列。而且效果很好。
但是我需要在图形初始化之前设置变量的形状。因此,我在出队操作后添加set_shape。这是一个问题。因为使用set_shape脚本会挂断。队列未装满,出队操作未执行。
我阅读了很多有关此主题的文档和讨论,但是不知道如何解决它。请以什么方式可以更改此脚本给我建议。预先感谢。