我正在尝试使用Tensorflow基于Jupyter Notebook的官方网站的RNN教程来实现单词嵌入。
基于repo,一个名为ptb_producer
的函数负责将文本数据(已转换为id)转换为批处理,并输出输入和标签的元组。在该函数中,张量流的range_input_producer
负责在队列数上以纪元数输出从0
到limit-1
的整数。
epoch_size = (batch_len - 1) // num_steps
assertion = tf.assert_positive(
epoch_size,
message="epoch_size == 0, decrease batch_size or num_steps")
with tf.control_dependencies([assertion]):
epoch_size = tf.identity(epoch_size, name="epoch_size")
i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
但是,如果尝试将ptb_producer
函数的输出集放入一个变量元组,则取元组的元素之一(假设输入),并获得该元素的嵌入(输入) ,该过程在Jupyter Notebook中暂停。
inputs, targets = ptb_producer(train_data, 32, 1, name="TrainStep")
with tf.device("/cpu:0"):
embeddings = tf.get_variable("embeddings", [10000, 300], dtype=tf.float32)
new_inputs = tf.nn.embedding_lookup(embeddings, inputs)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
sess.run(new_inputs)
当我尝试运行sess.run(new_inputs)
时,笔记本计算机暂停,并且没有任何内容输出到笔记本计算机。我也尝试运行sess.run((init, new_inputs))
,但笔记本电脑也暂停了。
Tensorflow的警告指出tf.train.range_input_producer
已过时,最好使用tf.data.Dataset.range(limit).shuffle(limit).repeat(num_epochs). If shuffle=False, omit the .shuffle(...)
。
我希望它会返回相同的输入和目标元组。但是当我尝试时,似乎tf.data.Dataset
不会返回整数。
epoch_size = (batch_len - 1) // num_steps
assertion = tf.assert_positive(
epoch_size,
message="epoch_size == 0, decrease batch_size or num_steps")
with tf.control_dependencies([assertion]):
epoch_size = tf.identity(epoch_size, name="epoch_size")
epoch_size = tf.cast(epoch_size, tf.int64)
i = tf.data.Dataset.range(epoch_size).repeat()
输出
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-12-2007122212cd> in <module>
----> 1 inputs, targets = ptb_producer(train_data[:1000], 32, 1, name="TrainStep")
<ipython-input-11-2a2491970140> in ptb_producer(raw_data, batch_size, num_steps, name)
31
32 i = tf.data.Dataset.range(epoch_size).repeat()
---> 33 x = tf.strided_slice(data, [0, i * num_steps],
34 [batch_size, (i + 1) * num_steps])
35 x.set_shape([batch_size, num_steps])
TypeError: unsupported operand type(s) for *: 'RepeatDataset' and 'int'
再次,我的尝试是能够看到每个单词的嵌入。我想知道我是否错过了什么?