I understand that getting data into the computation graph is a major bottleneck in TensorFlow, and I have read that passing data through feed_dict is generally slow and should be avoided. To test this, I built two experiments whose actual computation is trivially simple, so passing the data should account for a large share of the total time.
The feed_dict test:

```python
import tensorflow as tf
import numpy as np
from time import time

batch_size = 128
num_runs = 200

# Evaluation using feed_dicts
tf.reset_default_graph()
rng = np.random.RandomState(12345)  # Seed rng to check results are identical
x_ph = tf.placeholder(tf.float32, shape=[None, 50])
w_ph = tf.placeholder(tf.float32, shape=[None, 50])
y = tf.reduce_mean(2*x_ph*w_ph)

sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
sess.run(tf.global_variables_initializer())

# Generate random data
x_data = rng.randn(batch_size*num_runs, 50).astype(np.float32) + 0.1
w_data = rng.randn(batch_size*num_runs, 50).astype(np.float32) + 0.1

# Just to eliminate the possibility that creating these views takes time
datasets_x = [x_data[batch_size*r:batch_size*(r+1)] for r in range(num_runs)]
datasets_w = [w_data[batch_size*r:batch_size*(r+1)] for r in range(num_runs)]

start_time_s = time()
running_sum = 0
for r in range(num_runs):
    running_sum += sess.run(y, feed_dict={x_ph: datasets_x[r], w_ph: datasets_w[r]})
elapsed_time_s = time() - start_time_s

print(f'Took {elapsed_time_s:.2f} seconds.')
print(f'Running sum = {running_sum}')
```
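As a sanity check on what the graph is computing (a NumPy-only sketch, no TensorFlow), the same per-batch running sum of mean(2*x*w) can be reproduced directly; with x and w drawn from N(0.1, 1), each batch mean is about 0.02, so 200 batches should sum to roughly 4, consistent with the outputs below:

```python
import numpy as np

batch_size = 128
num_runs = 200

# Same seeded data as in the TF experiments
rng = np.random.RandomState(12345)
x_data = rng.randn(batch_size * num_runs, 50).astype(np.float32) + 0.1
w_data = rng.randn(batch_size * num_runs, 50).astype(np.float32) + 0.1

# Accumulate mean(2*x*w) per batch, mirroring the tf.reduce_mean in the graph
running_sum = 0.0
for r in range(num_runs):
    x = x_data[batch_size * r : batch_size * (r + 1)]
    w = w_data[batch_size * r : batch_size * (r + 1)]
    running_sum += np.mean(2 * x * w)

print(running_sum)
```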
The Dataset test:

```python
import tensorflow as tf
import numpy as np
from time import time

batch_size = 128
num_runs = 200

# Evaluation using a dataset
tf.reset_default_graph()
rng = np.random.RandomState(12345)  # Seed rng to check results are identical
data_x_ph = tf.placeholder(tf.float32, [None, 50])
data_w_ph = tf.placeholder(tf.float32, [None, 50])
dataset = tf.data.Dataset.from_tensor_slices((data_x_ph, data_w_ph))
iterator = dataset.batch(batch_size).make_initializable_iterator()
next_element = iterator.get_next()
x_tf = next_element[0]
w_tf = next_element[1]
y = tf.reduce_mean(2*x_tf*w_tf)

sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
sess.run(tf.global_variables_initializer())

# Generate random data
x_data = rng.randn(batch_size*num_runs, 50).astype(np.float32) + 0.1
w_data = rng.randn(batch_size*num_runs, 50).astype(np.float32) + 0.1

start_time_s = time()
running_sum = 0
sess.run(iterator.initializer, feed_dict={data_x_ph: x_data, data_w_ph: w_data})
init_time_s = time() - start_time_s
for r in range(num_runs):
    running_sum += sess.run(y)
elapsed_time_s = time() - start_time_s

print(f'Took {elapsed_time_s:.2f} seconds.')
print(f'{init_time_s:.2f} of this was initialization time.')
print(f'Running sum = {running_sum}')
```
Now, when I run these in a Jupyter notebook, the feed_dict test outputs:

```
Took 0.15 seconds.
Running sum = 4.138764334471489
```

and the Dataset test outputs:

```
Took 0.26 seconds.
0.03 of this was initialization time.
Running sum = 4.138764334471489
```

The Dataset example is slower than the feed_dict example, and the slowness does not appear to be due to loading the numpy arrays in iterator.initializer.
If I instead run them from a terminal, by copying each notebook cell into files feed_dict.py and dataset.py, I get this (with the various TF logging messages cut out):

```
$ python feed_dict.py
Took 0.74 seconds.
Running sum = 4.138764334471489

$ python dataset.py
Took 0.88 seconds.
0.03 of this was initialization time.
Running sum = 4.138764334471489
```
The log messages in the terminal show that everything is being placed on the GPU, and adding

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
```

at the top of these files to disable the GPU changes the run times significantly. In summary:
| Environment | Time using `dataset` [s] | Time using `feed_dict` [s] |
| ----------------- | ------------------------ | -------------------------- |
| Jupyter | 0.26 | 0.15 |
| Commandline (GPU) | 0.88 | 0.74 |
| Commandline (CPU) | 0.14 | 0.06 |
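One thing I have not ruled out is that the very first sess.run of each script absorbs one-off setup cost, which would pollute such short timings. A TF-free sketch of the warm-up timing pattern I could use (timed_runs is a hypothetical helper, not part of the code above):

```python
from time import time

def timed_runs(fn, num_runs, warmup=1):
    """Time num_runs calls of fn, after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()                      # warm-up: absorbs one-off setup cost
    start = time()
    for _ in range(num_runs):
        fn()
    return time() - start

# Usage sketch: elapsed = timed_runs(lambda: sess.run(y), num_runs=200)
```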
The naive conclusion seems to be that feed_dict is faster than Dataset! But this conflicts with everything I have read. Can anyone explain what I am seeing?