TensorFlow: where is tf.nn.conv2d actually executed?

Time: 2016-01-17 05:42:09

Tags: python machine-learning tensorflow

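I am curious about the TensorFlow implementation of tf.nn.conv2d(...). To call it, one simply runs tf.nn.conv2d(...). However, I'm going down the rabbit hole trying to see where it is executed. The code is as follows (each arrow points to the function that it ultimately calls):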

tf.nn.conv2d(...) -> tf.nn_ops.conv2d(...) -> tf.gen_nn_ops.conv2d(...) -> _op_def_lib.apply_op("Conv2D", ...) -> ?

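I'm familiar with TensorFlow's implementation of LSTMs and how easily they can be manipulated as one sees fit. Is the function that performs the conv2d() calculation written in Python, and if so, where is it? Can I see where and how the strides are executed?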

2 answers:

Answer 0 (score: 24)

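TL;DR: The implementation of tf.nn.conv2d() is written in C++. It invokes optimized code using either Eigen (on CPU) or the cuDNN library (on GPU). You can find the implementation here.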

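The functions in the chain you mentioned in the question (from tf.nn.conv2d() down) are Python functions for building a TensorFlow graph, but they do not invoke the implementation. Recall that in TensorFlow, you first build a symbolic graph, then execute it.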
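To illustrate the build-then-execute split, here is a toy sketch in plain Python. It is not TensorFlow's API, and `Node`, `placeholder`, `op`, `run` and `CALLS` are hypothetical names. Calling `op(...)` only records a node, much as `tf.nn.conv2d()` only adds a `Conv2D` node to the graph. Nothing is computed until `run()` evaluates the fetched node and the nodes it depends on.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List

CALLS: List[str] = []  # hypothetical log of every node whose function actually ran


@dataclass
class Node:
    """A symbolic node: it describes a computation but holds no result."""
    name: str
    fn: Callable[..., Any]
    inputs: List["Node"] = field(default_factory=list)


def placeholder(name: str) -> Node:
    """A node whose value must be supplied through `feeds` at run time."""
    def missing() -> Any:
        raise ValueError(f"You must feed a value for {name!r}")
    return Node(name, missing)


def op(name: str, fn: Callable[..., Any], *inputs: Node) -> Node:
    """Graph building: only records the node, computes nothing."""
    return Node(name, fn, list(inputs))


def run(fetch: Node, feeds: Dict[str, Any]) -> Any:
    """Execution: evaluate just the nodes that `fetch` depends on."""
    cache: Dict[int, Any] = {}  # per-run results, keyed by node identity

    def evaluate(node: Node) -> Any:
        if node.name in feeds:
            return feeds[node.name]
        if id(node) not in cache:
            args = [evaluate(i) for i in node.inputs]
            cache[id(node)] = node.fn(*args)
            CALLS.append(node.name)
        return cache[id(node)]

    return evaluate(fetch)


x = placeholder("x")
y = op("double", lambda v: 2 * v, x)
z = op("plus_one", lambda v: v + 1, y)
print(CALLS)             # prints [] because building the graph computed nothing
print(run(z, {"x": 5}))  # 11
print(CALLS)             # ['double', 'plus_one']
```

The first `print` shows the empty log, which is the point of the sketch. Only the call to `run()` (the analogue of `Session.run()`) causes any node function to execute.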

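The implementation of tf.nn.conv2d() is only executed when you call Session.run() and pass a Tensor whose value depends on the result of some convolution. For example:

    import tensorflow as tf  # TensorFlow 1.x graph-mode API

    input = tf.placeholder(tf.float32)
    filter = tf.Variable(tf.truncated_normal([5, 5, 3, 32], stddev=0.1))
    conv = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    result = sess.run(conv, feed_dict={input: ...})  # <== Execution happens here.

Invoking sess.run(...) tells TensorFlow to run all the ops that are needed to compute the value of conv, including the convolution itself. The path from here to the implementation is somewhat complicated, but it goes through the following steps: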

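  1. sess.run() calls the TensorFlow backend to fetch the value of conv.
  2. The backend prunes the computation graph to work out which nodes must be executed, and places the nodes on the appropriate devices (CPU or GPU).
  3. Each device is instructed to execute its subgraph, using an executor.
  4. The executor eventually invokes the tensorflow::OpKernel that corresponds to the convolution operator, by calling its Compute() method.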
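Steps 2 to 4 above can be mimicked in a few lines of plain Python. The fetched node is traced back to its ancestors, which prunes everything else. The surviving nodes are ordered so that inputs run first. Then each node's op type is looked up in a kernel registry, and the registered function is called the way an `OpKernel`'s `Compute()` method would be. This is a hypothetical sketch: `KERNELS`, `pruned_order` and `execute` are illustrative names, not TensorFlow's C++ classes, and device placement is omitted.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical graph: node name -> (op type, input node names, attribute).
Graph = Dict[str, Tuple[str, List[str], Any]]

# Hypothetical kernel registry keyed by op type, standing in for the
# OpKernel that TensorFlow registers for each op (e.g. "Conv2D").
KERNELS: Dict[str, Callable[..., Any]] = {
    "Const": lambda attr: attr,
    "Add": lambda attr, a, b: a + b,
    "Mul": lambda attr, a, b: a * b,
}


def pruned_order(graph: Graph, fetch: str) -> List[str]:
    """Keep only the ancestors of `fetch`, listed inputs-first (DFS post-order).

    Assumes the graph is acyclic, as a TensorFlow dataflow graph is.
    """
    order: List[str] = []
    seen = set()

    def visit(name: str) -> None:
        if name in seen:
            return
        seen.add(name)
        _, inputs, _ = graph[name]
        for i in inputs:
            visit(i)
        order.append(name)

    visit(fetch)
    return order


def execute(graph: Graph, fetch: str) -> Any:
    """Run each surviving node by dispatching to its registered kernel."""
    values: Dict[str, Any] = {}
    for name in pruned_order(graph, fetch):
        op_type, inputs, attr = graph[name]
        kernel = KERNELS[op_type]  # look up the "OpKernel" for this op type
        values[name] = kernel(attr, *(values[i] for i in inputs))  # its "Compute()"
    return values[fetch]


graph: Graph = {
    "a": ("Const", [], 3),
    "b": ("Const", [], 4),
    "sum": ("Add", ["a", "b"], None),
    "prod": ("Mul", ["a", "b"], None),  # not needed to fetch "sum", so it is pruned
}
print(pruned_order(graph, "sum"))  # ['a', 'b', 'sum']
print(execute(graph, "sum"))       # 7
```

Keying the registry by op type keeps graph construction separate from the kernels. The Python side only has to name the op, and the runtime selects the implementation.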

The "Conv2D" OpKernel is implemented here, and its Compute() method is here. Because this op is performance-critical for many workloads, the implementation is quite complicated. The basic idea is that the computation is offloaded to either the Eigen Tensor library (if running on CPU) or cuDNN's optimized GPU implementation.
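The Eigen and cuDNN code paths are heavily optimized, but what they compute can be written as a naive reference loop, which also shows where the strides come in. The plain-Python sketch below follows the `tf.nn.conv2d` conventions. It assumes NHWC input `[batch, height, width, in_channels]`, a filter `[filter_height, filter_width, in_channels, out_channels]`, and `strides = [1, stride_h, stride_w, 1]`, and it uses TensorFlow's documented `'SAME'` padding rule. Like `tf.nn.conv2d`, it computes a cross-correlation, so the filter is not flipped. It is a hypothetical reference for the semantics only, not TensorFlow's implementation, and `conv2d_reference` is an illustrative name.

```python
import math
from typing import List

Tensor4 = List[List[List[List[float]]]]  # nested lists: NHWC input or HWIO filter


def conv2d_reference(inp: Tensor4, filt: Tensor4, strides: List[int], padding: str) -> Tensor4:
    """Naive 2-D convolution (cross-correlation) with strides and padding."""
    batch, in_h, in_w, in_c = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    k_h, k_w, f_in_c, out_c = len(filt), len(filt[0]), len(filt[0][0]), len(filt[0][0][0])
    if in_c != f_in_c:
        raise ValueError("input channels must match the filter's in_channels")
    _, s_h, s_w, _ = strides

    if padding == "SAME":
        out_h, out_w = math.ceil(in_h / s_h), math.ceil(in_w / s_w)
        # Total padding is split with the extra row/column going at the end.
        pad_top = max((out_h - 1) * s_h + k_h - in_h, 0) // 2
        pad_left = max((out_w - 1) * s_w + k_w - in_w, 0) // 2
    elif padding == "VALID":
        out_h = math.ceil((in_h - k_h + 1) / s_h)
        out_w = math.ceil((in_w - k_w + 1) / s_w)
        pad_top = pad_left = 0
    else:
        raise ValueError(f"unknown padding {padding!r}")

    out = [[[[0.0] * out_c for _ in range(out_w)] for _ in range(out_h)] for _ in range(batch)]
    for b in range(batch):
        for oy in range(out_h):
            for ox in range(out_w):
                # The stride is applied here: output pixel (oy, ox) reads the
                # window whose top-left corner is (oy*s_h - pad_top, ox*s_w - pad_left).
                y0, x0 = oy * s_h - pad_top, ox * s_w - pad_left
                for oc in range(out_c):
                    acc = 0.0
                    for dy in range(k_h):
                        for dx in range(k_w):
                            y, x = y0 + dy, x0 + dx
                            if 0 <= y < in_h and 0 <= x < in_w:  # outside = zero padding
                                for ic in range(in_c):
                                    acc += inp[b][y][x][ic] * filt[dy][dx][ic][oc]
                    out[b][oy][ox][oc] = acc
    return out


# 1x4x4x1 input holding 1..16 row by row, and a 2x2x1x1 filter of ones.
inp = [[[[float(r * 4 + c + 1)] for c in range(4)] for r in range(4)]]
filt = [[[[1.0]] for _ in range(2)] for _ in range(2)]
result = conv2d_reference(inp, filt, [1, 2, 2, 1], "VALID")
print([[px[0] for px in row] for row in result[0]])  # [[14.0, 22.0], [46.0, 54.0]]
```

With stride 2, the window jumps two pixels at a time, so the 4x4 input yields a 2x2 output of non-overlapping 2x2 sums. The optimized kernels produce the same values but restructure these loops heavily for speed.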

Answer 1 (score: 1)

A TensorFlow program consists of two discrete parts:

  • Building the computation graph.

tf.nn.conv2d(...) -> tf.nn_ops.conv2d(...) -> tf.gen_nn_ops.conv2d(...) -> _op_def_lib.apply_op("Conv2D", ...) -> graph.create_op -> register the op in the graph
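The graph-construction half of that chain can be sketched as follows. An `apply_op`-style helper looks up the registered op definition by name (here `"Conv2D"`), checks the attributes, and asks the graph to create and register a node. No convolution is computed. The registry contents and the `OP_DEFS`, `Operation`, `Graph` and `apply_op` definitions below are hypothetical simplifications, not the real `OpDefLibrary` or `tf.Graph` code.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

# Hypothetical op registry: op name -> validator for each required attribute.
OP_DEFS: Dict[str, Dict[str, Callable[[Any], bool]]] = {
    "Conv2D": {
        "strides": lambda s: isinstance(s, list) and len(s) == 4,
        "padding": lambda p: p in ("SAME", "VALID"),
    },
}


@dataclass
class Operation:
    name: str
    type: str
    inputs: List[str]
    attrs: Dict[str, Any]


@dataclass
class Graph:
    ops: Dict[str, Operation] = field(default_factory=dict)

    def create_op(self, op_type: str, inputs: List[str], attrs: Dict[str, Any],
                  name: str) -> Operation:
        unique, n = name, 1
        while unique in self.ops:  # make node names unique, e.g. "Conv2D_1"
            unique, n = f"{name}_{n}", n + 1
        op = Operation(unique, op_type, list(inputs), dict(attrs))
        self.ops[unique] = op  # "register the op in the graph"
        return op


def apply_op(graph: Graph, op_type: str, inputs: List[str],
             name: Optional[str] = None, **attrs: Any) -> Operation:
    """Validate attrs against the registered op definition, then add a node."""
    op_def = OP_DEFS[op_type]  # KeyError for an unregistered op
    for attr, is_valid in op_def.items():
        if attr not in attrs or not is_valid(attrs[attr]):
            raise ValueError(f"{op_type}: invalid or missing attr {attr!r}")
    return graph.create_op(op_type, inputs, attrs, name or op_type)


g = Graph()
op1 = apply_op(g, "Conv2D", ["input", "filter"], strides=[1, 1, 1, 1], padding="SAME")
op2 = apply_op(g, "Conv2D", ["input", "filter"], strides=[1, 2, 2, 1], padding="VALID")
print(op1.name, op2.name)  # Conv2D Conv2D_1
print(len(g.ops))          # 2
```

The result of this phase is only a description: two `Conv2D` entries with their strides stored as attributes, which a later run phase would read.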

  • Running the computation graph.

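sess = tf.Session(target) -> sess.run(conv2d) -> the master prunes the full graph to the client graph -> the master splits the client graph by task into graph partitions -> the graph partitions are registered with the workers -> each worker splits its subgraph by device into graph partitions -> the master notifies all workers to run their graph partitions -> each worker notifies all of its devices to run their graph partitions -> the executor runs the ops on each device in topological order.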
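The last steps of that pipeline consist of splitting the graph into per-device partitions and having each executor run its partition in topological order. Both can be sketched with Kahn's algorithm. This is a hypothetical simplification: `partition` and `topo_sort` are illustrative names, and placement is given as a plain dict. TensorFlow replaces cross-device edges with send/receive ops, and here those inputs are simply assumed to be available already.

```python
from collections import deque
from typing import Dict, List


def partition(edges: Dict[str, List[str]], placement: Dict[str, str]) -> Dict[str, List[str]]:
    """Group nodes by their assigned device; edges[n] lists n's inputs."""
    parts: Dict[str, List[str]] = {}
    for node in edges:
        parts.setdefault(placement[node], []).append(node)
    return parts


def topo_sort(nodes: List[str], edges: Dict[str, List[str]]) -> List[str]:
    """Kahn's algorithm restricted to one partition.

    Inputs that live in another partition are treated as already received
    (standing in for send/recv), so they do not count toward in-degree.
    """
    local = set(nodes)
    indegree = {n: sum(1 for i in edges[n] if i in local) for n in nodes}
    consumers: Dict[str, List[str]] = {n: [] for n in nodes}
    for n in nodes:
        for i in edges[n]:
            if i in local:
                consumers[i].append(n)

    ready = deque(n for n in nodes if indegree[n] == 0)
    order: List[str] = []
    while ready:
        n = ready.popleft()
        order.append(n)  # the executor would call this op's kernel here
        for c in consumers[n]:
            indegree[c] -= 1
            if indegree[c] == 0:
                ready.append(c)
    if len(order) != len(nodes):
        raise ValueError("cycle detected in partition")
    return order


edges = {"input": [], "filter": [], "conv": ["input", "filter"],
         "relu": ["conv"], "loss": ["relu"]}
placement = {"input": "/cpu:0", "filter": "/cpu:0", "conv": "/gpu:0",
             "relu": "/gpu:0", "loss": "/cpu:0"}
parts = partition(edges, placement)
for device, nodes in parts.items():
    print(device, topo_sort(nodes, edges))
# /cpu:0 ['input', 'filter', 'loss']
# /gpu:0 ['conv', 'relu']
```

In a real runtime, `loss` on the CPU would block until the GPU partition sends `relu`'s output. This sketch sidesteps that by ignoring cross-partition inputs when computing in-degrees.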

For each op, the executor invokes the kernel implementation to compute it.

The kernel implementation of tf.nn.conv2d() is written in C++. It invokes optimized code using either Eigen (on CPU) or the cuDNN library (on GPU).