TensorFlow - 使用多进程读取 TFRecords 并保存到字典

时间:2017-11-07 16:22:59

标签: tensorflow python-multiprocessing

我正在尝试加快把选定的 TFRecord 文件转换为一组 Python 字典的速度。下面是我目前的代码。CPU 利用率一开始达到峰值,随后几乎降为零,这表明我的代码没有正常工作。

我的目标是构建 3 个字典并用 pickle 保存。共有 14,000 多个 tfrecord 文件(约 2 GB)。按目前的速度,单进程运行大约需要 84 个小时。

我使用 Manager 字典(manager.dict())的方式是否有问题?

import glob
import tensorflow as tf
import cPickle
import numpy as np
from tqdm import tqdm

import collections
from multiprocessing import Process, Manager, Pool



def get_multihot_encoding(example_label):
    """Build a 10-slot multi-hot vector for the labels of one example.

    Each label that appears among the values of the module-level ``lookup``
    mapping sets the slot given by ``lookup_inverted[label]`` to 1; every
    other label is ignored.

    Args:
        example_label: iterable of label ids attached to one example.

    Returns:
        A list of 10 floats, each 0.0 or 1.0.
    """
    encoding = np.zeros(10)
    for class_id in example_label:
        # Skip labels that are not among the tracked classes.
        if class_id not in lookup.values():
            continue
        encoding[lookup_inverted[class_id]] = 1
    return list(encoding)


# Multiprocessing set-up: a Manager server process owns three shared
# dictionaries; every worker reads/writes them through proxy objects.
manager = Manager()
audio_embeddings_dict = manager.dict()
audio_labels_dict = manager.dict()
audio_multihot_dict = manager.dict()
# Module-level TensorFlow session.
# NOTE(review): this session is created before any worker pool is forked;
# TensorFlow sessions are generally not safe to share across fork() --
# confirm whether worker processes should be touching it at all.
sess = tf.Session()

# Every .tfrecord file under the unbalanced-train directory; this list is
# the iterable handed to the per-file worker function.
all_tfrecord_filenames = glob.glob('/Users/jeff/features/audioset_v1_embeddings/unbal_train/*.tfrecord')



def _dump_snapshot(shared_dict, path):
    """Pickle a plain-dict copy of `shared_dict` to `path` and close the file."""
    # .copy() yields an ordinary dict. Pickling a Manager proxy directly
    # would serialize the proxy handle rather than the stored entries.
    snapshot = shared_dict.copy()
    with open(path, 'wb') as handle:
        cPickle.dump(snapshot, handle)


def process_tfrecord(tfrecord):
    """Read one TFRecord file and add its matching examples to the shared dicts.

    For every record whose labels intersect the module-level
    ``label_filters`` set, three entries keyed by video id are stored:
    the per-frame audio embeddings, the raw label list and the multi-hot
    label encoding. Every 100 records (including record 0) the three shared
    dictionaries are pickled to disk as a checkpoint.

    Args:
        tfrecord: path of the TFRecord file, as accepted by
            ``tf.python_io.tf_record_iterator``.

    Returns:
        None. Results are written to ``audio_embeddings_dict``,
        ``audio_labels_dict`` and ``audio_multihot_dict``.
    """
    for idx, example in enumerate(tf.python_io.tf_record_iterator(tfrecord)):
        tf_example = tf.train.Example.FromString(example)
        vid_id = tf_example.features.feature['video_id'].bytes_list.value[0].decode(encoding='UTF-8')
        example_label = list(np.asarray(tf_example.features.feature['labels'].int64_list.value))

        # Only create dict entries when the example carries a wanted label.
        matched = set(example_label) & label_filters
        if matched:
            print(matched, " Is the intersection of the two")
            tf_seq_example = tf.train.SequenceExample.FromString(example)
            frames = tf_seq_example.feature_lists.feature_list['audio_embedding'].feature

            # Decode the raw uint8 bytes with NumPy instead of building and
            # evaluating new TensorFlow ops for every frame: the values are
            # identical (uint8 -> float32), but the graph no longer grows
            # per call and no session is needed inside worker processes.
            audio_frame = [
                np.frombuffer(frame.bytes_list.value[0], dtype=np.uint8).astype(np.float32)
                for frame in frames
            ]

            audio_embeddings_dict[vid_id] = audio_frame
            audio_labels_dict[vid_id] = example_label
            audio_multihot_dict[vid_id] = get_multihot_encoding(example_label)

        if idx % 100 == 0:
            print ("Saving dictionary at loop: {}".format(idx))
            # NOTE(review): file names depend only on idx, so concurrent
            # workers overwrite each other's checkpoints -- confirm whether
            # a per-file or per-process suffix is wanted.
            _dump_snapshot(audio_embeddings_dict,
                           'audio_embeddings_dict_unbal_train_multi_{}.pkl'.format(idx))
            _dump_snapshot(audio_multihot_dict,
                           'audio_multihot_dict_bal_untrain_multi_{}.pkl'.format(idx))
            # The labels file previously received the multi-hot dict again.
            _dump_snapshot(audio_labels_dict,
                           'audio_labels_unbal_dict_multi_{}.pkl'.format(idx))




# Fan the TFRecord files out over a pool of 50 worker processes and block
# until every file has been processed.
pool = Pool(50)
try:
    result = pool.map(process_tfrecord, all_tfrecord_filenames)
finally:
    # Always release the worker processes, even if a worker raised.
    # map() has already returned (or failed) here, so no task is lost.
    pool.terminate()
    pool.join()

0 个答案:

没有答案