How do I use batch_scatter_update in an RNN with TensorFlow?

Asked: 2019-05-30 07:34:00

Tags: python tensorflow

I am trying to use batch_scatter_update inside the step function of a class that inherits from TensorFlow's BasicDecoder. However, an error occurs at runtime.

Source code

import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import ops
from tensorflow.python.layers import core as layers_core


class CopyDecoder(basic_decoder.BasicDecoder):
    def __init__(self, cell, helper, initial_state,
                 batch_size, output_size,
                 output_layer=None, internal_size=100):
        super().__init__(cell, helper, initial_state, output_layer)
        self.internal_size = internal_size
        self.output_size = output_size
        self.batch_size_ = batch_size
        self.W = layers_core.Dense(
            self.internal_size, name="W", use_bias=False)
        self.index = tf.tile(tf.expand_dims(tf.range(self.internal_size), axis=0), [batch_size, 1])
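
    def step(self, time, inputs, state, name=None):
        with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
            cell_outputs, cell_state = self._cell(inputs, state)
            # Project the cell output to internal_size
            # (the original line used the undefined name `cell_output`).
            weight = self.W(cell_outputs)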

            def init_value():
                return tf.tile(tf.zeros([1, self.output_size],
                                        dtype=tf.float32),
                               [self.batch_size_, 1])
            ref = tf.Variable(
                    initial_value=init_value, validate_shape=False)
            scatter = ref.batch_scatter_update(
                tf.IndexedSlices(weight, self.index))
            cell_outputs = scatter

            sample_ids = self._helper.sample(
                time=time, outputs=cell_outputs, state=cell_state)
            (finished, next_inputs, next_state) = self._helper.next_inputs(
                time=time,
                outputs=cell_outputs,
                state=cell_state,
                sample_ids=sample_ids)
        outputs = basic_decoder.BasicDecoderOutput(cell_outputs, sample_ids)
        return (outputs, next_state, next_inputs, finished)
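To make the intent of the scatter in step concrete, here is a minimal plain-Python sketch of the update it is meant to perform. It assumes the documented 2-D semantics of batch_scatter_update, ref[b, indices[b, j]] = updates[b, j]. The helper name batch_scatter_update_reference and the small sizes are hypothetical and only serve as an illustration; this is not TensorFlow code.

```python
def batch_scatter_update_reference(ref, indices, updates):
    """Hypothetical reference helper (not a TensorFlow API).

    Returns a copy of `ref` where, for every batch row b and position j,
    out[b][indices[b][j]] = updates[b][j].
    """
    out = [row[:] for row in ref]  # copy so the input is left untouched
    for b, (index_row, update_row) in enumerate(zip(indices, updates)):
        for column, value in zip(index_row, update_row):
            out[b][column] = value
    return out


# Illustrative sizes only; the question uses internal_size=100.
batch_size, internal_size, output_size = 2, 3, 5

# Mirrors init_value(): zeros of shape [batch_size, output_size].
ref = [[0.0] * output_size for _ in range(batch_size)]
# Mirrors self.index: range(internal_size) tiled over the batch.
index = [list(range(internal_size)) for _ in range(batch_size)]
# Stands in for the projected cell output `weight`.
weight = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

scattered = batch_scatter_update_reference(ref, index, weight)
```

Because index holds the same range in every row, the result under these assumptions is simply weight padded with zeros on the right up to output_size.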

Error message

Caused by op 'IteratorGetNext', defined at:
  File "code.py", line 33, in <module>
    model.train()
  File "/home/code/model.py", line 100, in train
    config=self.config)
  File "/home/code/reader.py", line 52, in __init__
    self.output_tensors = self.compute_output()
  File "/home/code/reader.py", line 278, in compute_output
    return self.iterator.get_next()
  File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 414, in get_next
    output_shapes=self._structure._flat_shapes, name=name)
  File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 1685, in iterator_get_next
    output_shapes=output_shapes, name=name)
  File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3300, in create_op
    op_def=op_def)
  File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1801, in __init__
    self._traceback = tf_stack.extract_stack()

FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
  [[node IteratorGetNext (defined at /home/code/reader.py:278) ]]

The error seems to originate elsewhere (the traceback points at the iterator), but everything works normally when I use the original BasicDecoder.

Please tell me how to fix this.

0 answers:

No answers yet.