如何在具有tensorflow的RNN中使用batch_scatter_update?
我试图在从batch_scatter_update
的{{1}}继承的类的步进函数中使用BasicDoder
。但是,在运行时会发生错误。
源代码
tensorflow
错误消息
import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import ops
from tensorflow.python.layers import core as layers_core
class CopyDecoder(basic_decoder.BasicDecoder):
def __init__(self, cell, helper, initial_state,
batch_size, output_size,
output_layer=None, internal_size=100):
super().__init__(cell, helper, initial_state, output_layer)
self.internal_size = internal_size
self.output_size = output_size
self.batch_size_ = batch_size
self.W = layers_core.Dense(
self.internal_size, name="W", use_bias=False)
self.index = tf.tile(tf.expand_dims(tf.range(self.internal_size), axis=0), [batch_size, 1])
def step(self, time, inputs, state, name=None):
with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
cell_outputs, cell_state =\
self._cell(inputs, state)
weight = self.W(cell_output)
def init_value():
return tf.tile(tf.zeros([1, self.output_size],
dtype=tf.float32),
[self.batch_size_, 1])
ref = tf.Variable(
initial_value=init_value, validate_shape=False)
scatter = ref.batch_scatter_update(
tf.IndexedSlices(weight, self.index))
cell_outputs = scatter
sample_ids = self._helper.sample(
time=time, outputs=cell_outputs, state=cell_state)
(finished, next_inputs, next_state) = self._helper.next_inputs(
time=time,
outputs=cell_outputs,
state=cell_state,
sample_ids=sample_ids)
outputs = basic_decoder.BasicDecoderOutput(cell_outputs, sample_ids)
return (outputs, next_state, next_inputs, finished)
该错误似乎发生在任何地方,但是使用原始的Caused by op 'IteratorGetNext', defined at:
File "code.py", line 33, in <module>
model.train()
File "/home/code/model.py", line 100, in train
config=self.config)
File "/home/code/reader.py", line 52, in __init__
self.output_tensors = self.compute_output()
File "/home/code/reader.py", line 278, in compute_output
return self.iterator.get_next()
File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 414, in get_next
output_shapes=self._structure._flat_shapes, name=name)
File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 1685, in iterator_get_next
output_shapes=output_shapes, name=name)
File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
op_def=op_def)
File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3300, in create_op
op_def=op_def)
File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1801, in __init__
self._traceback = tf_stack.extract_stack()
FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
[[node IteratorGetNext (defined at /home/code/reader.py:278) ]]
可以正常工作。
请告诉我解决方法