可能是How to decode vggish audioset embeddings from tfrecord?的副本,但需要在How can I extract the audio embeddings (features) from Google’s AudioSet?之上进行改进
我有以下代码可以使用tensorflow从Google的音频集中读取功能。它运行非常快,除了它尝试解析嵌入的行。当我分析此代码时,它表明它花费了大部分时间使用Extednddsession和TfSessionRecords。
它位于with tf.Session() as sess
上下文中。
是否有一种更快的方式将数组读取为numpy并将其缓存在所有tfrecords的h5文件中?
def index_tfrecord(conn, h5file, tfrecord, split_index):
cursor = conn.cursor()
ebar = tqdm.tqdm(list(tf.python_io.tf_record_iterator(tfrecord)))
for example in ebar:
tf_example = tf.train.Example.FromString(example)
f = tf_example.features.feature
video_id = (f['video_id'].bytes_list.value[0]
).decode(encoding='UTF-8')
ebar.set_description('youtube.com/%s' % video_id)
# TODO: This is not failproof. This may skip videos that have
# multiple labels
sql = 'SELECT COUNT(*) FROM labels_videos WHERE video_id = ?'
cursor.execute(sql, (video_id,))
num_rows = cursor.fetchone()[0]
if num_rows:
continue
if len(f['start_time_seconds'].float_list.value) > 1:
print(video_id)
start_time_seconds = (f['start_time_seconds']
).float_list.value[0]
end_time_seconds = (f['end_time_seconds']
).float_list.value[0]
label_ids = list(np.asarray(
tf_example.features.feature['labels'].int64_list.value))
tf_seq_example = tf.train.SequenceExample.FromString(example)
fl = (tf_seq_example.feature_lists
).feature_list['audio_embedding']
n_frames = len(fl.feature)
if video_id not in h5file.keys():
audio_frames = [tf.cast(tf.decode_raw(
fl.feature[i].bytes_list.value[0], tf.uint8), tf.uint8
).eval()
for i in range(n_frames)]
arr = np.array(audio_frames)
h5file.create_dataset(
name=video_id,
data=arr)
h5file.flush()
else:
logging.info('%s has already been indexed' % video_id)
sql = (
'INSERT INTO videos'
'(video_id)'
'VALUES (?)'
)
cursor.execute(sql, (
video_id,))
conn.commit()
sql = (
'INSERT INTO labels_videos'
'(video_id, label_id, split_id, start_time_seconds,'
' end_time_seconds)'
'VALUES (?, ?, ?, ?, ?)'
)
params = tuple((video_id, int(label_id), split_index,
start_time_seconds, end_time_seconds)
for label_id in label_ids)
cursor.executemany(sql, params)