我想将“vgg16.npy”加载到我自己编写的vggnet中。 我想知道你是否可以帮助我。谢谢你。
答案 0 :(得分:1)
您无法将.npy
文件的权重/偏差直接复制到您的网络,因为它们将被视为常量。因此,首先需要初始化tf.Variable
:
conv1_weights = tf.get_variable('conv1_weights', initializer=data['conv1']['weights'])
conv1_biases = tf.get_variable('conv1_biases', initializer=data['conv1']['biases'])
这是我的代码中的一个真实示例。 data 是 np.load 函数创建的对象。
然后,将其添加到您的网络:
conv_kernel_1 = tf.nn.conv2d(x, **conv1_weights**, strides=[1,2,2,1], padding='VALID', name='conv1')
bias_layer_1 = relu(tf.nn.bias_add(conv_kernel_1, conv1_biases))
我希望这个例子很明确。
答案 1 :(得分:0)
以下是我转换静态'的解决方案。 npy文件的长度为张量。 1.使用tf.read_file()获取原始字符串 2.根据https://docs.scipy.org/doc/numpy-dev/neps/npy-format.html,您可以使用tf.substr()从原始字符串中获取额外值 3.使用tf.py_func和struct.unpack将二进制原始字符串转换为tf.float64
希望这可以帮到你。 :)