Tensorflow:如何将“.npy”文件加载到网络中

时间:2017-11-01 15:16:32

标签: python tensorflow

我想将“vgg16.npy”加载到我自己编写的vggnet中。 我想知道你是否可以帮助我。谢谢你。

2 个答案:

答案 0 :(得分:1)

您无法将.npy文件的权重/偏差直接复制到您的网络,因为它们将被视为常量。因此,首先需要初始化tf.Variable

conv1_weights = tf.get_variable('conv1_weights', initializer=data['conv1']['weights'])

conv1_biases = tf.get_variable('conv1_biases', initializer=data['conv1']['biases'])

这是我的代码中的一个真实示例。 data np.load 函数创建的对象。

然后,将其添加到您的网络:

conv_kernel_1 = tf.nn.conv2d(x, **conv1_weights**, strides=[1,2,2,1], padding='VALID', name='conv1')

bias_layer_1 = relu(tf.nn.bias_add(conv_kernel_1, conv1_biases))

我希望这个例子很明确。

答案 1 :(得分:0)

以下是我转换静态'的解决方案。 npy文件的长度为张量。 1.使用tf.read_file()获取原始字符串 2.根据https://docs.scipy.org/doc/numpy-dev/neps/npy-format.html,您可以使用tf.substr()从原始字符串中获取额外值 3.使用tf.py_func和struct.unpack将二进制原始字符串转换为tf.float64

希望这可以帮到你。 :)