tensorflow错误:InvalidArgumentError:元组组件1中的形状不匹配。预期为[1],得到为[5]

时间:2018-07-03 03:19:48

标签: python-2.7 tensorflow

我正在尝试构造一批(wav_file,label)对。 WAV文件标签和路径在dev.csv中列出。 下面的代码不起作用,

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

threads = 1
batch_size = 5
global record_defaults
record_defaults = [['/Users/phoenix/workspace/dataset/data_thchs30/dev/A11_101.wav'], ['8.26'], ['七十 年代 末 我 外出 求学 母亲 叮咛 我 吃饭 要 细嚼慢咽 学习 要 深 钻 细 研']]

def read_record(filename_queue, num_records):
    reader = tf.TextLineReader()
    key, value = reader.read_up_to(filename_queue, num_records)
    wav_filename, duration, transcript = tf.decode_csv(value, record_defaults, field_delim=",")

    wav_reader = tf.WholeFileReader()
    wav_key, wav_value = wav_reader.read_up_to(tf.train.string_input_producer(wav_filename, shuffle=False, capacity=num_records), num_records)
    return [wav_key, transcript] # throw errors
    # return [wav_key, wav_value]  # works
    # return [wav_filename, duration, transcript]   # works

data_queue = tf.train.string_input_producer(tf.train.match_filenames_once('dev.csv'), shuffle=False)  
batch_data = [read_record(data_queue, batch_size) for _ in range(threads)]
capacity = threads * batch_size
batch_values = tf.train.batch_join(batch_data, batch_size=batch_size, capacity=capacity, enqueue_many=True)

init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    sess.run(tf.initialize_local_variables())
    coord = tf.train.Coordinator()
    print(coord)
    threads = tf.train.start_queue_runners(coord=coord)
    print("threads num: " + str(threads))
    try:
        step = 0
        while not coord.should_stop():
            step += 1
            feat = sess.run([batch_values])
            print("line:", step, feat)

    except tf.errors.OutOfRangeError:
        print(' training for 1 epochs, %d steps', step)
    finally:
        coord.request_stop()
        coord.join(threads)

在下面抛出错误,我该如何解决?:

enter image description here

dev.csv内容如下:

  

/Users/phoenix/workspace/dataset/data_thchs30/dev/A11_101.wav,8.26,《时代》杂志的外文翻译

     

/Users/phoenix/workspace/dataset/data_thchs30/dev/A11_119.wav,6.9,陈云彤世要秋风感冒不忍人雪雪

1 个答案:

答案 0 :(得分:0)

我试图这样重写您的代码。

这是我的观察。

  1. 不再抛出该错误。并返回值。
  2. 一个明显的差异是,脚本的批处理大小是指定大小的两倍。因此它是4而不是2。由于某种原因,它加倍了。音频二进制文件没有这种问题。

  3.   

    shapes=[tf.TensorShape(()),tf.TensorShape(batch_size,)]是基于我看到的一个错误,该错误提到我必须使用TensorShape指定此错误。我没有找到任何帮助的documentation,但在那里被提及。

  

shapes :(可选。)完整定义的TensorShape对象的列表,其长度与dtypes相同,或者为None。

import tensorflow as tf

tf.logging.set_verbosity(tf.logging.DEBUG)

FLAGS = tf.app.flags.FLAGS

threads = 1
batch_size = 2
record_defaults = [['D:/male.wav'], ['8.26'], ['七十 年代 末 我 外出 求学 母亲 叮咛 我 吃饭 要 细嚼慢咽 学习 要 深 钻 细 研']]


def readbatch(data_queue) :

    reader = tf.TextLineReader()
    _, rows = reader.read_up_to(data_queue, batch_size)
    wav_filename, duration, transcript = tf.decode_csv(rows, record_defaults,field_delim=",")
    audioreader = tf.WholeFileReader()
    _, audio = audioreader.read( tf.train.string_input_producer(wav_filename) )
    return [audio,transcript]

data_queue = tf.train.string_input_producer(
                tf.train.match_filenames_once('D:/Book1.csv'), shuffle=False)

batch_data = [readbatch(data_queue) for _ in range(threads)]
capacity = threads * batch_size
batch_values = tf.train.batch_join(batch_data, shapes=[tf.TensorShape(()),tf.TensorShape(batch_size,)], capacity=capacity, batch_size=batch_size, enqueue_many=False )

init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    sess.run(tf.initialize_local_variables())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        step = 0
        while not coord.should_stop():
            step += 1
            feat = sess.run([batch_values])
            audio = feat[0][0]
            print ('Size of audio is ' + str(audio.size))
            script = feat[0][1]
            print ('Size of script is ' + str(script.size))
    except tf.errors.OutOfRangeError:
        print(' training for 1 epochs, %d steps', step)
    finally:
        coord.request_stop()
        coord.join(threads)

样本数据集证明存在额外的一对。

[[array([b'Text2', b'Text1'], dtype=object), array([[b'Translation-1', b'Translation-2'],
       [b'Translation-1', b'Translation-2']], dtype=object)]]