此问题与此SO post和此github discussion
有关我使用TFRecord存储我的数据,如果我对形状进行硬编码,代码就可以工作。但是,我需要从TFRecord获取形状。
def _parse_function(example_proto):
keys_to_features = {'image_size': tf.FixedLenFeature((), tf.int64),
'seq_steps': tf.FixedLenFeature((), tf.int64),
'K': tf.FixedLenFeature((), tf.int64),
'T': tf.FixedLenFeature((), tf.int64),
'input_seq': tf.FixedLenFeature((), tf.string)}
parsed_features = tf.parse_single_example(example_proto, keys_to_features)
image_size = tf.cast(parsed_features['image_size'],tf.int32)
seq_steps = tf.cast(parsed_features['seq_steps'],tf.int32)
K = tf.cast(parsed_features['K'],tf.int32)
T = tf.cast(parsed_features['T'],tf.int32)
input_shape = [image_size, image_size, seq_steps*(K+T), 2]
input_seq = tf.reshape(tf.decode_raw(parsed_features['input_seq'], tf.float32), tf.stack(input_shape), name='reshape_input_seq')
return input_seq
返回的张量具有动态形状,当通过模型时,在某个点调用tf.reshape然后返回一个错误,因为张量形状只是部分已知。
我不知道该怎么办。我尝试通过在返回之前设置input_seq.set_shape(tf.stack(input_shape))
来解决此问题。
然而,这给了我另一个错误TypeError: Tensor objects are not iterable when eager execution is not enabled. To iterate over this tensor use tf.map_fn.
要解决此问题,我尝试dataset = dataset.map(lambda elems: tf.map_fn(_parse_function, elems))
而非dataset = dataset.map(_parse_function)
然后返回ValueError: slice index 0 of dimension 0 out of bounds.
任何帮助将不胜感激!