Fastest way to read features from Google's AudioSet with TensorFlow

Date: 2019-02-12 01:18:00

Tags: tensorflow tensorflow-datasets

Possibly a duplicate of "How to decode vggish audioset embeddings from tfrecord?", but this question needs to improve on "How can I extract the audio embeddings (features) from Google’s AudioSet?"

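I have the following code that reads the features from Google's AudioSet with TensorFlow. It runs very fast, except for the line that parses the embeddings. When I profile the code, it spends most of its time in ExtendSession and TfSessionRecords.

The function runs inside a `with tf.Session() as sess` context.

Is there a faster way to read the arrays into NumPy and cache them in an HDF5 (h5) file for all the tfrecords?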

import logging

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.python_io, tf.decode_raw, Tensor.eval)
import tqdm


def index_tfrecord(conn, h5file, tfrecord, split_index):
    """Copy the audio embeddings of one AudioSet tfrecord into an HDF5 file
    and record the video/label metadata in SQLite.

    Must be called inside a `with tf.Session() as sess:` block, because
    Tensor.eval() runs in the default session.
    """
    cursor = conn.cursor()
    # Materialize the records so that tqdm can display a total.
    ebar = tqdm.tqdm(list(tf.python_io.tf_record_iterator(tfrecord)))
    for example in ebar:
        # AudioSet records are SequenceExamples: per-clip metadata lives in
        # the context, per-second embeddings in feature_lists.
        tf_seq_example = tf.train.SequenceExample.FromString(example)
        f = tf_seq_example.context.feature
        video_id = f['video_id'].bytes_list.value[0].decode(encoding='UTF-8')
        ebar.set_description('youtube.com/%s' % video_id)

        # TODO: This is not failproof. This may skip videos that have
        # multiple labels
        sql = 'SELECT COUNT(*) FROM labels_videos WHERE video_id = ?'
        cursor.execute(sql, (video_id,))
        num_rows = cursor.fetchone()[0]
        if num_rows:
            continue

        if len(f['start_time_seconds'].float_list.value) > 1:
            print(video_id)
        start_time_seconds = f['start_time_seconds'].float_list.value[0]
        end_time_seconds = f['end_time_seconds'].float_list.value[0]

        label_ids = list(f['labels'].int64_list.value)

        fl = tf_seq_example.feature_lists.feature_list['audio_embedding']
        n_frames = len(fl.feature)

        if video_id not in h5file:
            # Slow part: every tf.decode_raw call adds new ops to the default
            # graph, so each .eval() must first extend the session's graph
            # before it can run.
            audio_frames = [
                tf.decode_raw(fl.feature[i].bytes_list.value[0], tf.uint8).eval()
                for i in range(n_frames)]

            arr = np.array(audio_frames)
            h5file.create_dataset(name=video_id, data=arr)
            h5file.flush()
        else:
            logging.info('%s has already been indexed' % video_id)

        sql = 'INSERT INTO videos (video_id) VALUES (?)'
        cursor.execute(sql, (video_id,))

        sql = (
            'INSERT INTO labels_videos '
            '(video_id, label_id, split_id, start_time_seconds,'
            ' end_time_seconds) '
            'VALUES (?, ?, ?, ?, ?)'
            )
        params = tuple((video_id, int(label_id), split_index,
                        start_time_seconds, end_time_seconds)
                       for label_id in label_ids)
        cursor.executemany(sql, params)
        # Commit after both inserts so the labels of the last video are saved too.
        conn.commit()
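One possible fix is sketched below. It assumes the standard AudioSet layout, where each frame is a single 128-byte string holding an 8-bit quantized embedding. `tf.train.SequenceExample.FromString` has already parsed the protobuf in plain Python, so the raw frame bytes can be decoded with `np.frombuffer` instead of `tf.decode_raw(...).eval()`. This should give the same uint8 values without adding ops to the graph or running a session. The helper name `frames_to_array` and the constant `EMBEDDING_DIM` are hypothetical.

```python
import numpy as np

# Assumption: AudioSet embeddings are 128-dimensional, 8-bit quantized
# vectors, stored as one 128-byte string per frame.
EMBEDDING_DIM = 128


def frames_to_array(frame_bytes):
    """Decode an iterable of raw uint8 frame strings into an
    (n_frames, EMBEDDING_DIM) uint8 array using NumPy only.

    Hypothetical helper: produces the same values as calling
    tf.decode_raw(frame, tf.uint8) on every frame, without a session.
    """
    # Join all frames once, then reinterpret the buffer as uint8.
    buf = b"".join(frame_bytes)
    flat = np.frombuffer(buf, dtype=np.uint8)
    # Raises ValueError if the byte count is not a multiple of EMBEDDING_DIM.
    return flat.reshape(-1, EMBEDDING_DIM)
```

In the loop above, this could replace the list comprehension, for example `arr = frames_to_array(feat.bytes_list.value[0] for feat in fl.feature)`, with the `h5file.create_dataset(name=video_id, data=arr)` call left unchanged. `np.frombuffer` returns a read-only view of the joined bytes, so call `.copy()` if the array needs to be modified.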

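The loop also runs one `SELECT COUNT(*)` and one `commit()` per record. The profile points at the embedding decode rather than these calls, so this is a secondary cost. If the metadata writes ever start to matter, one possible sketch is to load the already-indexed ids once and write each tfrecord's rows in a single transaction. `load_indexed_ids` and `index_records` are hypothetical helpers, and the table and column names are taken from the SQL in the question. `records` is assumed to be an iterable of `(video_id, label_ids, start_time_seconds, end_time_seconds)` tuples read from the context features.

```python
import sqlite3


def load_indexed_ids(conn):
    """Fetch every video_id already present in labels_videos with one query."""
    return {row[0] for row in
            conn.execute('SELECT DISTINCT video_id FROM labels_videos')}


def index_records(conn, records, split_index):
    """Insert metadata for new videos in one transaction (hypothetical helper).

    Returns the number of videos that were inserted.
    """
    seen = load_indexed_ids(conn)
    video_rows = []
    label_rows = []
    for video_id, label_ids, start, end in records:
        if video_id in seen:
            continue  # skip videos that already have label rows
        seen.add(video_id)  # also skips duplicates within this batch
        video_rows.append((video_id,))
        label_rows.extend((video_id, int(label_id), split_index, start, end)
                          for label_id in label_ids)
    # `with conn` commits once on success and rolls back on an exception.
    with conn:
        conn.executemany('INSERT INTO videos (video_id) VALUES (?)',
                         video_rows)
        conn.executemany(
            'INSERT INTO labels_videos '
            '(video_id, label_id, split_id, start_time_seconds, '
            'end_time_seconds) VALUES (?, ?, ?, ?, ?)',
            label_rows)
    return len(video_rows)
```

Using `with conn:` means a failed insert rolls back the whole batch instead of leaving a partially indexed tfrecord in the database.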