我需要多次执行语句sess.run()
。
我在代码的开头创建了sess
一次。但是,每个sess.run()
语句在我的CPU机器上花费大约0.5-0.8秒。有什么办法可以优化吗?由于Tensorflow会进行延迟加载,有什么方法可以让它不能做到这一点,并使其更快?
我正在使用图像分类示例中的Inception模型。
def load_network():
with gfile.FastGFile('model.pb', 'rb') as f:
graph_def = tf.GraphDef()
data = f.read()
graph_def.ParseFromString(data)
png_data = tf.placeholder(tf.string, shape=[])
decoded_png = tf.image.decode_png(png_data, channels=3)
_ = tf.import_graph_def(graph_def, name=input_map={'DecodeJpeg': decoded_png})
return png_data
def get_pool3(sess, png_data, imgBuffer):
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
pool3Vector = sess.run(pool3, {png_data: imgBuffer.getvalue()})
return pool3Vector
def main():
sess = getTensorSession()
png_data = load_network()
# The below line needs to be called multiple times, which is what takes
# nearly 0.5-0.8 seconds.
# imgBuffer contains the stored value of the image.
pool3 = get_pool3(sess, png_data, imgBuffer)
答案 0 :(得分:2)
Tensorflow懒惰地运行操作---在调用sess.run()之前实际上没有计算任何内容。当您调用sess.run()时,Tensorflow会执行计算图中的所有操作。因此,如果sess.run()花费0.5-0.8秒,那么你的计算本身可能需要0.5-0.8秒。
(sess.run()有一些开销,但它不应该接近半秒左右。)
希望有所帮助!
添加了:
以下是一些可以加快计算速度的方法:
使用Tensorflow的分析工具来查看计算的哪些部分花费时间。它们尚未记录,但您可以在此github问题中找到有关它们的一些信息:https://github.com/tensorflow/tensorflow/issues/1824
让您的计算更便宜 - 降低模型的复杂性,使用更小的图像等。
在GPU而不是CPU上运行计算。