如何在Tensorflow中填充元组数据进行批处理?

时间:2018-03-20 08:30:44

标签: python tensorflow

我无法在Tensorflow中执行填充批处理dataset.padded_batch(),这是我的代码:

此代码抛出错误,无法正常工作,我认为是因为某些原因我使用自己的代码来计算MFCC,但是如果我不调用padded_batch它会起作用

我收到以下错误:

“InvalidArgumentError:批处理中的所有元素必须与component0的填充形状具有相同的排名:预期排名1但得到排名为2的元素      [[Node:IteratorGetNext = IteratorGetNextoutput_shapes = [[?,?],[?,?],[?,?]],output_types = [DT_DOUBLE,DT_STRING,DT_INT64],_ device =“/ job:localhost / replica:0 /任务:0 /设备:CPU:0“]]”

def _read_py_function(audio, label):
    audio = audio_to_mfcc(audio)
    original_length=audio.shape[0]

    #if audio.shape[0] < timesteps:
     #   original_length=audio.shape[0]
       # print(original_length)

    #elif audio.shape[0] >= timesteps:
     #   original_length=timesteps


    #audio=normalized(pad(audio) , axis=1 )
   # audio=pad(audio)

    return audio ,label, original_length

dataset = tf.contrib.data.TextLineDataset("hardik_250_docker.csv")
dataset=dataset.map(decode_csv)



dataset = dataset.map(
    lambda audio, label: tuple(tf.py_func(
        _read_py_function, [audio, label], [tf.double, label.dtype, tf.int64])))


shapes=([None]  , [None], [None])
dataset=dataset.padded_batch(2, padded_shapes=shapes)

0 个答案:

没有答案