I am trying to switch from tf.Session() to tf.train.MonitoredTrainingSession (all on one machine, no fancy distributed computing), but I am getting an error that I don't fully understand:
W tensorflow/core/framework/op_kernel.cc:1148] Invalid argument: Shape [16,-1,4] has negative dimensions
E tensorflow/core/common_runtime/executor.cc:644] Executor failed to create kernel. Invalid argument: Shape [16,-1,4] has negative dimensions
[[Node: define_inputs/Placeholder = Placeholder[dtype=DT_FLOAT, shape=[16,?,4], _device="/job:local/replica:0/task:0/cpu:0"]()]]
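The `-1` in `Shape [16,-1,4]` is the unknown middle dimension that appears as `?` in `shape=[16,?,4]`, because TensorFlow serializes an unknown dimension as -1. The plain-Python sketch below mimics that encoding and a check that a concrete tensor shape may not contain an unknown dimension. `encode_shape` and `check_fully_defined` are hypothetical helpers, not TensorFlow APIs, and they only illustrate the wording of the message, not TensorFlow's kernel code.

```python
# Hypothetical helpers (NOT TensorFlow APIs) showing how "?" ends up as -1.

def encode_shape(dims):
    """Encode a partially known shape, writing unknown (None) dimensions as -1."""
    return [-1 if d is None else d for d in dims]


def check_fully_defined(encoded):
    """Reject a shape that still contains an unknown (-1) dimension."""
    if any(d < 0 for d in encoded):
        raise ValueError(f"Shape {encoded} has negative dimensions")
    return encoded
```

For example, `encode_shape([16, None, 4])` returns `[16, -1, 4]`, and passing that to `check_fully_defined` raises `ValueError: Shape [16, -1, 4] has negative dimensions`.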
Further down, I get more information about the error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1139, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1121, in _run_fn
    status, run_metadata)
  File "/usr/local/Cellar/python3/3.6.1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/contextlib.py", line 89, in __exit__
    next(self.gen)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
I am using tf.contrib.seq2seq, and my input and output sequences have variable lengths, e.g. x_placeholder = tf.placeholder(tf.float32, [batch_size, None, 4]).
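With a `[batch_size, None, 4]` placeholder, each fed batch must still be rectangular: only the middle (time) dimension may change from one batch to the next. A plain-Python sketch of padding one batch to its longest sequence follows. `pad_batch` is a hypothetical helper, not part of `tf.contrib.seq2seq`, and zero padding is an assumption.

```python
# Hypothetical helper: pad variable-length sequences of 4-feature steps into
# one rectangular batch of shape [batch_size, max_len, 4], plus true lengths.

def pad_batch(sequences, num_features=4, pad_value=0.0):
    if not sequences:
        raise ValueError("need at least one sequence")
    max_len = max(len(seq) for seq in sequences)
    batch, lengths = [], []
    for seq in sequences:
        for step in seq:
            if len(step) != num_features:
                raise ValueError(f"expected {num_features} features, got {len(step)}")
        # Append padding steps so every sequence reaches max_len.
        padding = [[pad_value] * num_features for _ in range(max_len - len(seq))]
        batch.append([list(step) for step in seq] + padding)
        lengths.append(len(seq))
    return batch, lengths
```

Two batches padded this way can have different `max_len` values, which is exactly what the `None` in the placeholder's shape allows; the returned lengths play the role of a per-sequence length vector such as `y_lengths_train`.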
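I suspect that the queue I use to read the data and bucket it by sequence length is somehow failing, or is being interrupted by MonitoredTrainingSession, because I don't run into this problem with a vanilla Session.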
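Bucketing by sequence length groups examples of similar length so that each batch needs little padding. The plain-Python sketch below is an illustration of the idea only, not the queue-based pipeline used in the question; `bucket_by_length`, its `boundaries` semantics (bucket `i` holds lengths below `boundaries[i]`, the last bucket catches the rest), and dropping incomplete buckets are all assumptions.

```python
# Hypothetical plain-Python sketch of bucketing by length (not the author's queue).

def bucket_by_length(sequences, boundaries, batch_size):
    """Route each sequence to a length bucket and emit a batch whenever one fills."""
    buckets = [[] for _ in range(len(boundaries) + 1)]
    batches = []
    for seq in sequences:
        # First boundary the length falls below; otherwise the overflow bucket.
        idx = next((i for i, b in enumerate(boundaries) if len(seq) < b), len(boundaries))
        buckets[idx].append(seq)
        if len(buckets[idx]) == batch_size:
            batches.append(buckets[idx])
            buckets[idx] = []
    # Incomplete buckets left over at the end are dropped in this sketch.
    return batches
```

With lengths `[1, 5, 2, 6, 3, 7]`, `boundaries=[4]` and `batch_size=2`, the short sequences of lengths 1 and 2 form one batch and those of lengths 5 and 6 form another, while the lengths 3 and 7 stay in partial buckets.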
Here is how I set up the MonitoredTrainingSession:
# create a global step
global_step = tf.contrib.framework.get_or_create_global_step()
# define graph
model = import_model(global_step)
# create a one process cluster with an in-process server
server = tf.train.Server.create_local_server()
# define hooks for writing summaries and model variables to disk
hooks = construct_training_hooks(model.summary_op)
with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=True,
                                       hooks=hooks) as monitored_sess:
    # create coordinator to handle threading
    coord = tf.train.Coordinator()
    # start threads to enqueue input minibatches for training
    threads = tf.train.start_queue_runners(sess=monitored_sess, coord=coord)
    # train
    while not monitored_sess.should_stop():
        train_op(monitored_sess, model, x_train, y_train, y_lengths_train)
    # when done, ask the threads to stop
    coord.request_stop()
    # wait for threads to finish
    coord.join(threads)
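A monitored session differs from a plain `Session` in that every `run()` call is wrapped by the hooks: each hook may add its own fetches and feeds, which are merged into the caller's call and executed together with it. The simplified plain-Python model below illustrates that wrapping. `RunArgs`, `RecordingHook`, `hooked_run` and `fake_run` are hypothetical stand-ins that only loosely mirror `tf.train.SessionRunHook` and `tf.train.SessionRunArgs`; the real method signatures differ.

```python
# Simplified, hypothetical model of hook wrapping; NOT TensorFlow code.

class RunArgs:
    """What a hook asks to add to the next run() call."""
    def __init__(self, fetches=None, feed_dict=None):
        self.fetches = fetches or []
        self.feed_dict = feed_dict or {}


class RecordingHook:
    """Requests one extra fetch on every run, as a summary hook might."""
    def __init__(self, name):
        self.name = name
        self.seen = []

    def before_run(self):
        return RunArgs(fetches=[self.name])

    def after_run(self, results):
        self.seen.append(results[self.name])


def hooked_run(raw_run, hooks, fetches, feed_dict):
    """Merge hook fetches/feeds into the caller's call, run once, dispatch results."""
    all_fetches = list(fetches)
    all_feeds = dict(feed_dict)
    for hook in hooks:
        args = hook.before_run()
        all_fetches.extend(args.fetches)
        all_feeds.update(args.feed_dict)
    results = raw_run(all_fetches, all_feeds)
    for hook in hooks:
        hook.after_run(results)
    # The caller only sees the fetches it asked for.
    return {f: results[f] for f in fetches}


def fake_run(fetches, feed_dict):
    """Stand-in for a raw session run in which every fetch depends on placeholder "x"."""
    if "x" not in feed_dict:
        raise ValueError("placeholder x was not fed")
    return {f: f"{f}@{feed_dict['x']}" for f in fetches}
```

Here `hooked_run(fake_run, [hook], ["train"], {"x": 1})` returns only the caller's `"train"` result, while the hook records its own `"summary"` fetch computed in the same call with the same feed.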
Here is how I create the training hooks:
def construct_training_hooks(summary_op):
    hooks = [tf.train.StopAtStepHook(last_step=tf.flags.FLAGS.training_steps),
             tf.train.CheckpointSaverHook(checkpoint_dir=tf.flags.FLAGS.log_dir,
                                          saver=tf.train.Saver(),
                                          save_steps=10),
             tf.train.SummarySaverHook(output_dir=tf.flags.FLAGS.log_dir,
                                       summary_op=summary_op,
                                       save_steps=10)]
    return hooks
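`save_steps=10` and `last_step` are both step-based conditions on the global step. The plain-Python sketch below approximates them: a save fires on the first check and then whenever at least `save_steps` steps have passed since the last save, and training stops once the step count reaches `last_step`. `should_trigger` and `simulate` are hypothetical, and the exact step at which TensorFlow's hooks fire can differ by one depending on the version, so treat this as an approximation.

```python
# Hypothetical approximation of step-based hook triggering; not TensorFlow source.

def should_trigger(step, last_triggered, every_steps):
    """Fire on the first check, then once at least every_steps steps have passed."""
    if last_triggered is None:
        return True
    if step == last_triggered:
        return False
    return step >= last_triggered + every_steps


def simulate(last_step, save_steps):
    """Return the steps at which a save fires and the step at which training stops."""
    saves, last, step = [], None, 0
    while True:
        if should_trigger(step, last, save_steps):
            saves.append(step)
            last = step
        step += 1  # assume each training call advances the global step by one
        if step >= last_step:  # StopAtStepHook-style stop condition
            return saves, step
```

With `last_step=25` and `save_steps=10`, this sketch saves at steps 0, 10 and 20 and stops at step 25.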