如何运行手工编码的TF网络(无需培训)

时间:2018-07-10 11:36:48

标签: tensorflow

我创建了一个小型TF网络,尝试在其中手动初始化所​​有权重。我正在尝试给它一个特定的输入(所有输入),看看网络将产生什么。我这样做的原因是,我试图在R中复制TF模型并且存在一些差异,并且我想通过尝试在TF和R中复制一个小模型并比较结果来进行调试。

因此,模型的输入创建为:

shape = (1, 3, 128, 64)
x = np.ones(shape).astype(np.float32)
x = np.transpose(x, (0, 2, 3, 1))  # Convert to NHWC
network = tf.convert_to_tensor(x, dtype=tf.float32)

拥有一个网络非常简单:

Conv2d -> Batch Norm (BN) -> ELU activation -> Flatten -> Dense -> BN -> ELU

因此,我如下创建网络:

# w is a numpy array with the shape in (NHWC) format
network = tf.layers.conv2d(inputs=network, filters=32, kernel_size=(3, 3), padding='SAME', use_bias=False,
                           trainable=False, kernel_initializer=tf.constant_initializer(w))


# Batch normalization (The initialization variables are numpy arrays)
network = tf.nn.batch_normalization(x=network, mean=tf.convert_to_tensor(m), variance=tf.convert_to_tensor(v),
                                    offset=tf.convert_to_tensor(b), scale=tf.convert_to_tensor(s), variance_epsilon=1e-8)

# Activation
network = tf.nn.elu(network)

# Flatten the network
network = tf.layers.flatten(network)

# Fully connected
network = tf.layers.dense(inputs=network, units=128, use_bias=False, trainable=False)

# Batch normalization
network = tf.nn.batch_normalization(x=network, mean=tf.convert_to_tensor(m), variance=tf.convert_to_tensor(v),
                                    offset=tf.convert_to_tensor(b), scale=tf.convert_to_tensor(s), variance_epsilon=1e-8)

out = tf.nn.elu(network)

我认为我已经正确创建了网络。但是,现在我不知道如何运行它。我看过的在线示例似乎在训练和训练,以节省权重并重新加载保存的图形,但是我想知道是否有一种简单的方法可以简单地运行前向通过(我不需要进行任何训练,并且我已经进行了硬编码权重)并获得128维矢量,以便我可以验证输出?

1 个答案:

答案 0 :(得分:1)

const { MarkerDest, MarkerUpcoming, MarkerNext } = {
  MarkerDest: 'MarkerDest',
  MarkerUpcoming: 'MarkerUpcoming',
  MarkerNext: 'MarkerNext',
};

const markers = {
  [MarkerDest]: { foo: 'Bar' },
 };

with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # to actually initialize variables results = sess.run(out) # forward pass 现在是一个numpy数组,其中包含给定您在第一个代码块中创建的输入的网络输出。