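I am curious about TensorFlow's implementation of tf.nn.conv2d(...). To call it, one simply runs tf.nn.conv2d(...). However, I am going down the rabbit hole trying to see where it is actually executed. The call chain is as follows (each arrow points to the function that is ultimately called):
tf.nn.conv2d(...) -> tf.nn_ops.conv2d(...) -> tf.gen_nn_ops.conv2d(...) -> _op_def_lib.apply_op("Conv2D", ...) -> ?
I am familiar with TensorFlow's implementation of LSTMs and with how easily one can manipulate them as one sees fit. Is the function that performs the conv2d() computation written in Python, and if so, where is it? Can I see where and how the strides are executed?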
Answer 0 (score: 24)
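TL;DR: The implementation of tf.nn.conv2d() is written in C++, and it invokes optimized code using either Eigen (on CPU) or the cuDNN library (on GPU). You can find the implementation here.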
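The chain of functions that you mentioned in the question (from tf.nn.conv2d() down) consists of Python functions for building a TensorFlow graph, but these functions do not invoke the implementation. Recall that, in TensorFlow, you first build a symbolic graph, then execute it.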
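The implementation of tf.nn.conv2d() is only executed when you call Session.run(), passing a Tensor whose value depends on the result of some convolution. For example:
import tensorflow as tf
input = tf.placeholder(tf.float32)
filter = tf.Variable(tf.truncated_normal([5, 5, 3, 32], stddev=0.1))
conv = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
sess = tf.Session()
sess.run(tf.global_variables_initializer())
result = sess.run(conv, feed_dict={input: ...}) # <== Execution happens here.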
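Invoking sess.run(...) tells TensorFlow to run all the ops that are needed to compute the value of conv, including the convolution itself. The path from here to the implementation is somewhat complicated, but goes through the following steps: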
sess.run() calls the TensorFlow backend to fetch the value of conv.
The tensorflow::OpKernel that corresponds to the convolution operator is invoked by calling its Compute() method.
The "Conv2D" OpKernel is implemented here, and its Compute() method is here. Because this op is performance-critical for many workloads, the implementation is quite complicated, but the basic idea is that the computation is offloaded to either the Eigen Tensor library (if running on CPU) or cuDNN's optimized GPU implementation.
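To make the question about strides concrete, here is a minimal pure-Python sketch of what a strided convolution computes. It is not TensorFlow's code: the real kernel is C++ and delegates to Eigen or cuDNN. The function name conv2d_naive, the nested-list layout, and the restriction to a single image with 'VALID' padding are all assumptions made for illustration.

```python
def conv2d_naive(image, kernel, strides=(1, 1)):
    """Naive single-image convolution with 'VALID' padding (no padding).

    image:   nested lists shaped [in_height][in_width][in_channels]
    kernel:  nested lists shaped [filter_h][filter_w][in_channels][out_channels]
    strides: (stride_h, stride_w), the step between neighbouring windows
    """
    in_h, in_w, in_c = len(image), len(image[0]), len(image[0][0])
    f_h, f_w, out_c = len(kernel), len(kernel[0]), len(kernel[0][0][0])
    s_h, s_w = strides
    out_h = (in_h - f_h) // s_h + 1
    out_w = (in_w - f_w) // s_w + 1

    output = [[[0.0] * out_c for _ in range(out_w)] for _ in range(out_h)]
    for i in range(out_h):
        for j in range(out_w):
            # The stride decides where each input window starts.
            top, left = i * s_h, j * s_w
            for o in range(out_c):
                total = 0.0
                for di in range(f_h):
                    for dj in range(f_w):
                        for c in range(in_c):
                            total += image[top + di][left + dj][c] * kernel[di][dj][c][o]
                output[i][j][o] = total
    return output


# A 5x5 single-channel image holding 0..24 and a 3x3 all-ones filter.
image = [[[float(r * 5 + c)] for c in range(5)] for r in range(5)]
kernel = [[[[1.0]] for _ in range(3)] for _ in range(3)]
result = conv2d_naive(image, kernel, strides=(2, 2))
print(result)  # [[[54.0], [72.0]], [[144.0], [162.0]]]
```

With a stride of 2 in each direction, the 3x3 window visits only every other position, so the 5x5 input yields a 2x2 output. In tf.nn.conv2d() the same two numbers are the middle entries of the four-element strides argument.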
Answer 1 (score: 1)
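A TensorFlow program consists of two discrete parts:
Graph construction: tf.nn.conv2d(...) -> tf.nn_ops.conv2d(...) -> tf.gen_nn_ops.conv2d(...) -> _op_def_lib.apply_op("Conv2D", ...) -> graph.create_op -> the op is registered in the graph
Graph execution: sess = tf.Session(target) -> sess.run(conv2d) -> the master prunes the full graph to the client graph -> the master splits the client graph by task into graph partitions -> the graph partitions are registered with the workers -> each worker splits its subgraph by device into graph partitions -> the master then notifies all workers to run their graph partitions -> each worker notifies all of its devices to run the graph partitions -> the executor runs the ops on the device in topological order.
For each op, the executor invokes the kernel implementation to compute the op.
The kernel implementation of tf.nn.conv2d() is written in C++, and it invokes optimized code using either Eigen (on CPU) or the cuDNN library (on GPU).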
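The split between the two parts can be illustrated with a toy model in plain Python. This is only a sketch under stated assumptions, not TensorFlow's design: the Graph class, the KERNELS registry and the run() function are hypothetical names, everything runs on one "device", and the master/worker partitioning is left out. It shows only that building records ops, while running prunes to what the fetch needs and invokes a kernel per op in dependency order.

```python
# Hypothetical kernel registry: op type -> function that computes the op.
KERNELS = {
    "Const": lambda attrs, inputs: attrs["value"],
    "Add": lambda attrs, inputs: inputs[0] + inputs[1],
    "Mul": lambda attrs, inputs: inputs[0] * inputs[1],
}


class Graph:
    """Toy graph: stores op descriptions, never computes anything itself."""

    def __init__(self):
        self.nodes = {}  # name -> (op_type, input names, attrs)

    def create_op(self, name, op_type, inputs=(), **attrs):
        # Building only records the op in the graph; no kernel runs here.
        self.nodes[name] = (op_type, tuple(inputs), attrs)
        return name


def run(graph, fetch):
    """Compute `fetch`, visiting only the ops it depends on."""
    values, executed = {}, []

    def compute(name):
        if name in values:
            return values[name]
        op_type, inputs, attrs = graph.nodes[name]
        # Inputs are computed first, which yields a topological order.
        args = [compute(i) for i in inputs]
        values[name] = KERNELS[op_type](attrs, args)  # kernel invocation
        executed.append(name)
        return values[name]

    return compute(fetch), executed


g = Graph()
a = g.create_op("a", "Const", value=2)
b = g.create_op("b", "Const", value=3)
c = g.create_op("c", "Add", [a, b])
unused = g.create_op("unused", "Mul", [a, b])  # never fetched, so never run

value, executed = run(g, c)
print(value, executed)  # 5 ['a', 'b', 'c']
```

The "unused" op is registered in the graph but pruned at run time because the fetched value does not depend on it, which mirrors why the Python call chain in the question never reaches the convolution code.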