输入给DecodeRaw的不是8的倍数,不是双倍的大小

时间:2019-04-23 19:08:12

标签: python tensorflow tensorflow-datasets

我想从tfrecords建立一个tensorflow数据集。 这是我的代码:

def make_dataset():
   filenames = [train_tfrecords_dir + name for name in os.listdir(train_tfrecords_dir)] 
   dataset = tf.data.TFRecordDataset(filenames)

    def parser(record):
         keys_to_features = {
        "mhot_label_raw": tf.FixedLenFeature((), tf.string, default_value=""),
        "mel_spec_raw": tf.FixedLenFeature((), tf.string, default_value=""),
    }
        parsed = tf.parse_single_example(record, keys_to_features)

        mel_spec1d = tf.decode_raw(parsed['mel_spec_raw'], tf.float64)
        mhot_label = tf.decode_raw(parsed['mhot_label_raw'], tf.float64)
        mel_spec = tf.reshape(mel_spec1d, [30, 65,85])
        return {"mel_data": mel_spec}, mhot_label

   dataset = dataset.map(parser)
   dataset = dataset.repeat(num_epochs)
   dataset = dataset.batch(batch_size)
   iterator = dataset.make_one_shot_iterator()
   return iterator

但这会导致此错误:

 InvalidArgumentError: Input to DecodeRaw has length 165750 that is not a multiple of 8, the size of double
     [[Node: DecodeRaw = DecodeRaw[little_endian=true, out_type=DT_DOUBLE](ParseSingleExample/Squeeze_mel_spec_raw)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,30,65,85], [?,?]], output_types=[DT_DOUBLE, DT_DOUBLE], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

我该如何解决?我删除了 tf.decode_raw ,但是没有用

1 个答案:

答案 0 :(得分:0)

因为使用tf.decode_raw将图像转换为double(tf.float64)类型,其大小为8个字节。因此parsed ['mel_spec_raw']应该是8的倍数。您可以打印 parsed ['mel_spec_raw'] 的类型,它应该是 tf.string ,解释为什么parsed ['mel_spec_raw']的大小为165750。您可以将代码更改为:

try...catch
try...finally
try...catch...finally

它可能有效,因为tf.uint8的大小仅为1。如果要将类型转换为tf.float64,则可以使用tf.cast将类型转换为tf.float

# mel_spec1d = tf.decode_raw(parsed['mel_spec_raw'], tf.float64)
mel_spec1d = tf.decode_raw(parsed['mel_spec_raw'], tf.uint8)

希望这会对您有所帮助。