如何使用tf.data.Dataset实现Tensorflow的defcrated tf.range_input_producer函数?

时间:2019-02-16 21:08:05

标签: python tensorflow

我正在尝试使用Tensorflow基于Jupyter Notebook的官方网站的RNN教程来实现单词嵌入。

基于repo,一个名为ptb_producer的函数负责将文本数据(已转换为id)转换为批处理,并输出输入和标签的元组。在该函数中,张量流的range_input_producer负责在队列数上以纪元数输出从0limit-1的整数。

epoch_size = (batch_len - 1) // num_steps
    assertion = tf.assert_positive(
        epoch_size,
        message="epoch_size == 0, decrease batch_size or num_steps")
    with tf.control_dependencies([assertion]):
      epoch_size = tf.identity(epoch_size, name="epoch_size")

    i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()

但是,如果尝试将ptb_producer函数的输出集放入一个变量元组,则取元组的元素之一(假设输入),并获得该元素的嵌入(输入) ,该过程在Jupyter Notebook中暂停。

inputs, targets = ptb_producer(train_data, 32, 1, name="TrainStep")

with tf.device("/cpu:0"):
    embeddings = tf.get_variable("embeddings", [10000, 300], dtype=tf.float32)
    new_inputs = tf.nn.embedding_lookup(embeddings, inputs)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
sess.run(new_inputs)

当我尝试运行sess.run(new_inputs)时,笔记本计算机暂停,并且没有任何内容输出到笔记本计算机。我也尝试运行sess.run((init, new_inputs)),但笔记本电脑也暂停了。

Tensorflow的警告指出tf.train.range_input_producer已过时,最好使用tf.data.Dataset.range(limit).shuffle(limit).repeat(num_epochs). If shuffle=False, omit the .shuffle(...)

我希望它会返回相同的输入和目标元组。但是当我尝试时,似乎tf.data.Dataset不会返回整数。

epoch_size = (batch_len - 1) // num_steps
    assertion = tf.assert_positive(
        epoch_size,
        message="epoch_size == 0, decrease batch_size or num_steps")
    with tf.control_dependencies([assertion]):
      epoch_size = tf.identity(epoch_size, name="epoch_size")
      epoch_size = tf.cast(epoch_size, tf.int64)

    i = tf.data.Dataset.range(epoch_size).repeat()

输出

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-2007122212cd> in <module>
----> 1 inputs, targets = ptb_producer(train_data[:1000], 32, 1, name="TrainStep")

<ipython-input-11-2a2491970140> in ptb_producer(raw_data, batch_size, num_steps, name)
     31 
     32     i = tf.data.Dataset.range(epoch_size).repeat()
---> 33     x = tf.strided_slice(data, [0, i * num_steps],
     34                          [batch_size, (i + 1) * num_steps])
     35     x.set_shape([batch_size, num_steps])

TypeError: unsupported operand type(s) for *: 'RepeatDataset' and 'int'

再次,我的尝试是能够看到每个单词的嵌入。我想知道我是否错过了什么?

0 个答案:

没有答案