使用从TFRecords读取的值作为tf.reshape的参数

时间:2018-04-06 21:42:01

标签: python tensorflow

此问题与此SO post和此github discussion

有关

我使用TFRecord存储我的数据,如果我对形状进行硬编码,代码就可以工作。但是,我需要从TFRecord获取形状。

def _parse_function(example_proto):
    keys_to_features = {'image_size': tf.FixedLenFeature((), tf.int64),
                        'seq_steps': tf.FixedLenFeature((), tf.int64),
                        'K': tf.FixedLenFeature((), tf.int64),
                        'T': tf.FixedLenFeature((), tf.int64),
                        'input_seq': tf.FixedLenFeature((), tf.string)}

    parsed_features = tf.parse_single_example(example_proto, keys_to_features)

    image_size = tf.cast(parsed_features['image_size'],tf.int32)
    seq_steps = tf.cast(parsed_features['seq_steps'],tf.int32)
    K = tf.cast(parsed_features['K'],tf.int32)
    T = tf.cast(parsed_features['T'],tf.int32)

    input_shape = [image_size, image_size, seq_steps*(K+T), 2]

    input_seq = tf.reshape(tf.decode_raw(parsed_features['input_seq'], tf.float32), tf.stack(input_shape), name='reshape_input_seq')

    return input_seq

返回的张量具有动态形状,当通过模型时,在某个点调用tf.reshape然后返回一个错误,因为张量形状只是部分已知。

我不知道该怎么办。我尝试通过在返回之前设置input_seq.set_shape(tf.stack(input_shape))来解决此问题。

然而,这给了我另一个错误TypeError: Tensor objects are not iterable when eager execution is not enabled. To iterate over this tensor use tf.map_fn.

要解决此问题,我尝试dataset = dataset.map(lambda elems: tf.map_fn(_parse_function, elems))而非dataset = dataset.map(_parse_function)

然后返回ValueError: slice index 0 of dimension 0 out of bounds.

任何帮助将不胜感激!

0 个答案:

没有答案