我正在尝试使用Tensorflow的tf.contrib.training.batch_sequences_with_states为RNN准备数据,特别是对变长序列进行批处理和分段。
下面的代码是一个玩具示例,我希望它能够无误地运行:
import os
import tensorflow as tf
BATCH_SIZE = 10
STATE_SIZE = 5
STEP_SIZE = 20
input_key = tf.placeholder(tf.string, [])
input_sequences = {
"inputs": tf.placeholder(tf.float32, [None]),
"labels": tf.placeholder(tf.float32, [None])
}
input_context = {
"length": tf.placeholder(tf.int32)
}
initial_states = {
"cell": tf.zeros([STATE_SIZE], tf.float32)
}
# https://www.tensorflow.org/api_docs/python/tf/contrib/training/batch_sequences_with_states
batch = tf.contrib.training.batch_sequences_with_states(
input_key=input_key,
input_sequences=input_sequences,
input_context=input_context,
input_length=None, # infer lengths
initial_states=initial_states,
num_unroll=STEP_SIZE,
batch_size=BATCH_SIZE
)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
tf.train.start_queue_runners(sess=sess)
我没有提供任何实际数据或运行任何东西(除了初始化变量)。我正在做的就是启动队列跑步者。我希望这个例子能够成功运行(尽管没有做任何事情)。但是,当我运行它时会抛出异常:
2017-08-08 01:06:00.789556: E tensorflow/core/common_runtime/executor.cc:644] Executor failed to create kernel. Invalid argument: Shape [-1] has negative dimensions
[[Node: input_sequences_inputs = Placeholder[dtype=DT_FLOAT, shape=[?], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
2017-08-08 01:06:00.796772: E tensorflow/core/common_runtime/executor.cc:644] Executor failed to create kernel. Invalid argument: Shape [-1] has negative dimensions
[[Node: input_sequences_inputs = Placeholder[dtype=DT_FLOAT, shape=[?], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
2017-08-08 01:06:00.798030: E tensorflow/core/common_runtime/executor.cc:644] Executor failed to create kernel. Invalid argument: Shape [-1] has negative dimensions
[[Node: input_sequences_inputs = Placeholder[dtype=DT_FLOAT, shape=[?], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
2017-08-08 01:06:00.805150: E tensorflow/core/common_runtime/executor.cc:644] Executor failed to create kernel. Invalid argument: Shape [-1] has negative dimensions
[[Node: input_sequences_inputs = Placeholder[dtype=DT_FLOAT, shape=[?], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
2017-08-08 01:06:00.810993: E tensorflow/core/common_runtime/executor.cc:644] Executor failed to create kernel. Invalid argument: Shape [-1] has negative dimensions
[[Node: input_sequences_inputs = Placeholder[dtype=DT_FLOAT, shape=[?], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
2017-08-08 01:06:00.812120: E tensorflow/core/common_runtime/executor.cc:644] Executor failed to create kernel. Invalid argument: Shape [-1] has negative dimensions
[[Node: input_sequences_inputs = Placeholder[dtype=DT_FLOAT, shape=[?], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
2017-08-08 01:06:00.818859: E tensorflow/core/common_runtime/executor.cc:644] Executor failed to create kernel. Invalid argument: Shape [-1] has negative dimensions
[[Node: input_sequences_inputs = Placeholder[dtype=DT_FLOAT, shape=[?], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
2017-08-08 01:06:00.822800: E tensorflow/core/common_runtime/executor.cc:644] Executor failed to create kernel. Invalid argument: Shape [-1] has negative dimensions
[[Node: input_sequences_inputs = Placeholder[dtype=DT_FLOAT, shape=[?], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
2017-08-08 01:06:00.823670: E tensorflow/core/common_runtime/executor.cc:644] Executor failed to create kernel. Invalid argument: Shape [-1] has negative dimensions
[[Node: input_sequences_inputs = Placeholder[dtype=DT_FLOAT, shape=[?], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
我的目的是提供input_sequences
作为占位符,并在培训期间提供实际值。序列将是可变长度的,因此我不想为此占位符定义固定大小,因此[None]
和input_sequences["inputs"]
的形状为input_sequences["labels"]
。
文档建议您可以提供input_length=None
,input_sequences["inputs"]
中序列的长度将从实际序列中推断出来:
input_length
是Tensor标量或在填充之前记录时间维度的int。它应该在0和时间维度之间。我们想要跟踪它的一个原因是我们可以在计算损失时将其考虑在内。如果pad=True
,那么input_length
可以是None
,并且会被推断出来。https://www.tensorflow.org/api_docs/python/tf/contrib/training/batch_sequences_with_states
我做错了什么?
我使用的是Python 3.6.1,Tensorflow 1.2.1。