当我学习一个tensorflow项目时,找到一行代码:
cls_prob, box_pred = sess.run([output_cls_prob, output_box_pred], feed_dict={input_img: blob})
但是,此行代码花费了很多时间。 (使用CPU需要15秒...┭┮﹏┭┮)
通过咨询信息,我发现使用功能“数据集”可以解决此问题,这花费了很多时间,我应该如何使用它?
“ blob”的来源:
img = cv2.imread('./imgs/001.jpg')
img_scale = float(600) / min(img_data.shape[0], img_data.shape[1])
if np.round(img_scale * max(img_data.shape[0], img_data.shape[1])) > 1200:
img_scale = float(1200) / max(img_data.shape[0], img_data.shape[1])
img_data = cv2.resize(img_data, None, None, fx=img_scale, fy=img_scale, interpolation=cv2.INTER_LINEAR)
img_orig = img_data.astype(np.float32, copy=True)
blob = np.zeros((1, img_data.shape[0], img_data.shape[1], 3),dtype=np.float32)
blob[0, 0:img_data.shape[0], 0:img_data.shape[1], :] = img_orig
'output_cls_prob'&'output_box_pred'&'input_img'的来源:
# Actually,read PB model...
input_img = sess.graph.get_tensor_by_name('Placeholder:0')
output_cls_prob = sess.graph.get_tensor_by_name('Reshape_2:0')
output_box_pred = sess.graph.get_tensor_by_name('rpn_bbox_pred/Reshape_1:0')
参数类型:
blob:type 'numpy.ndarray'
output_cls_prob:class 'tensorflow.python.framework.ops.Tensor'
output_box_pred:class 'tensorflow.python.framework.ops.Tensor'
input_img:class 'tensorflow.python.framework.ops.Tensor'
答案 0 :(得分:1)
tf.data
是用于tensorflow输入管道的推荐API。这是有关tensorflow.org的教程。对于您的示例,第"Decoding image data and resizing it"节可能是最有用的。例如,您可以执行以下操作:
# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string)
image_resized = tf.image.resize_images(image_decoded, [new_width, new_height])
image_resized = tf.expand_dims(image_resized, 0) # Adds size 1 dimension
return image_resized
# A vector of filenames.
filenames = tf.constant(["./imgs/001.jpg", ...])
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.map(_parse_function)
并不要将input_img
用作占位符,而是进行更改:
input_img = tf.placeholder(tf.float32)
output_class_prob, output_class_pred = (... use input_img ...)
收件人:
iterator = dataset.make_one_shot_iterator()
input_img = iterator.get_next()
output_class_prob, output_class_pred = (... use input_img ...)
答案 1 :(得分:0)
首先,您应该知道使用多个GPU时,使用数据集API会对性能产生重大影响...否则,几乎与feed_dict相同。我建议您从TF开发人员那里阅读this other answer,它几乎是所有需要了解的关于创建这种新API好处的思想的信息。