Theano and TensorFlow conv2D produce different outputs

Time: 2018-02-21 16:27:57

Tags: python tensorflow deep-learning theano

The following code is written in Theano and produces the output shown below it:

import theano
import tensorflow as tf
from theano import tensor as T
from theano.tensor.nnet import conv2d

import numpy as np

np.random.seed(1)
wt_th = np.random.uniform(low=-0.5, high=0.5, size=(16,3,5,5)).astype(np.float32)

np.random.seed(1)
inp = np.random.rand(128,3,128,128).astype(np.float32)

np.random.seed(1)
bias = np.random.uniform(low=0, high=10, size=(16)).astype(np.float32)

# instantiate 4D tensor for input
input = T.tensor4(name='input')

# initialize shared variable for weights.
w_shp = (16, 3, 5, 5)
W = theano.shared(wt_th , name ='W')

b_shp = (16,)
b = theano.shared(bias, name ='b')

# build symbolic expression that computes the convolution of input with filters in w
conv_out = conv2d(input, W)

output = T.nnet.relu(conv_out + b.dimshuffle('x', 0, 'x', 'x'))

# create theano function to compute filtered images
f = theano.function([input], output)
res = f(inp)
print(res[0][0][0])

[2.8236432 2.094213 1.4916432 3.525494 2.3700824 1.8851945 2.2574215  3.3974087 1.9719648 1.3346338 0.7322583 1.4527869 2.9211016 2.4242344  2.6613848 3.1885512 2.8935843 3.7721367 1.0875871 1.8844371 3.6890957  2.1210446 3.4621592 2.2298138 2.1788187 3.1571674 2.080009 1.4983883  3.3549118 1.8853223 2.0242834 2.8072758 4.2562714 3.6012995 2.2535224  3.87668 3.3886926 3.697033 3.3373523 2.2016246 3.5874677 3.0154514  2.434566 3.6492867 2.2965183 2.6377907 2.2562 2.8330164 2.1103406  2.9778543 2.3738375 3.1129453 1.277472 1.1789091 2.4199317 2.619667  2.5976152 1.0020001 2.562955 1.6254797 1.9258347 1.5564928 3.5225492  3.1682463 2.179951 3.2768161 2.2703805 2.0199404 2.4948874 2.9022932  3.0263028 2.264034 1.9042997 1.6110027 3.6300693 1.899374 2.9140353  2.8552768 2.7125297 2.7972744 2.0619967 3.8458047 3.140479 1.6845248  3.844461 3.8562043 2.5270283 2.4488764 2.7029114 1.8886952 3.034019  3.1078124 1.9806297 4.573 2.769538 2.6645966 3.501518 2.2144883  1.8297508 3.3294327 2.7242799 2.187298 2.5060043 1.9938259 3.914175  3.7276266 2.6536622 2.896241 2.821738 1.592206 1.8782039 2.648998  2.284129 3.4120197 2.6911411 3.2339904 2.5738459 2.8637185 1.8006318  3.1124763 2.1838622 2.6475391 1.7801914 2.5641136]

W has the format (num_output_channel, num_input_channel, height, width).

The input has the format (batch_size, num_channels, height, width).
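Converting between the layouts described above is a pure axis permutation: filters go from (out, in, height, width) to TensorFlow's (height, width, in, out), and inputs go from NCHW to NHWC. The NumPy sketch below writes each conversion as a single `np.transpose` (the helper names `to_hwio` and `to_nhwc` are hypothetical, chosen for illustration) and checks that it matches the `np.moveaxis` sequence used in the question. No Theano or TensorFlow is needed to run it.

```python
import numpy as np

# Hypothetical helpers: each layout change written as a single transpose.
def to_hwio(w_oihw):
    # (out, in, height, width) -> (height, width, in, out)
    return np.transpose(w_oihw, (2, 3, 1, 0))

def to_nhwc(x_nchw):
    # (batch, channels, height, width) -> (batch, height, width, channels)
    return np.transpose(x_nchw, (0, 2, 3, 1))

rng = np.random.RandomState(1)
w = rng.uniform(-0.5, 0.5, size=(16, 3, 5, 5)).astype(np.float32)
x = rng.rand(2, 3, 8, 8).astype(np.float32)

w_tf = to_hwio(w)
x_tf = to_nhwc(x)
print(w_tf.shape)  # (5, 5, 3, 16)
print(x_tf.shape)  # (2, 8, 8, 3)

# The same permutations expressed as the moveaxis calls used in the question.
assert np.array_equal(w_tf, np.moveaxis(np.moveaxis(w, 0, 3), 0, 2))
assert np.array_equal(x_tf, np.moveaxis(x, 1, 3))
```

If these checks pass, the layout conversion by itself is unlikely to explain the mismatch.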

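Now I have written a function that converts the weights from Theano format to TensorFlow format, and I have also changed the input shape to the TF format. However, the code below produces output that differs from the output of the code above: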

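def convert_filter(wts):
    # (out_channels, in_channels, height, width) -> (height, width, in_channels, out_channels)
    wts = np.moveaxis(wts, 0, 3)  # -> (in_channels, height, width, out_channels)
    wts = np.moveaxis(wts, 0, 2)  # -> (height, width, in_channels, out_channels)
    return wts

def convert_input(inp):
    # (batch_size, channels, height, width) -> (batch_size, height, width, channels)
    inp = np.moveaxis(inp, 1, 3)
    return inp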

input_shape = [128,128, 3]

X = tf.placeholder(shape = [None] + input_shape, dtype=tf.float32, name='X')

wt_tf = convert_filter(wt_th).astype(np.float32)

conv_kernel_1 = tf.nn.conv2d(X, wt_tf, [1, 1, 1, 1], padding='VALID',use_cudnn_on_gpu=False)

bias_layer_1 = tf.nn.bias_add(conv_kernel_1, bias)

act_out = tf.nn.relu(bias_layer_1)

inp_tf = convert_input(inp)

with tf.Session() as sess:
    out = sess.run(act_out, feed_dict={X:inp_tf})
    out = np.moveaxis(out, 3, 1)  # back to NCHW to compare with the Theano result
print(out[0][0][0])

[2.67197418 2.15853548 2.06719136 3.01160574 3.22447252 3.077492   2.52125549 3.08207083 2.29633474 1.86849833 2.03281307 2.28387547   2.67936897 2.48002243 2.31078005 3.56169009 3.12560081 2.61774731   1.82814527 3.23375154 3.25905514 2.39252329 3.13444471 2.00132608   2.41169739 1.86714172 3.01640558 2.51328039 2.07797813 1.77424145   1.8954494 2.98585939 2.98480368 2.57455826 2.36318088 3.88532543   2.38877392 2.86067486 2.78855133 2.63732243 2.63163185 2.79659152   1.98354578 2.77975321 2.12787509 2.71589994 3.44908381 2.02305984   3.04079533 2.60647154 2.14657426 2.74537277 3.07799053 2.49051762   4.77739191 3.12529612 2.30980444 2.31344223 2.02293968 3.04298592   3.4453795 3.58379078 3.32912683 3.26278138 1.48381591 2.32841253   1.97166562 3.04377413 3.12559581 2.27840328 2.93908429 0.96808767   3.17380023 1.60673594 2.59704685 3.98458505 1.25713849 1.90271974   1.82997131 2.93574715 2.14195251 3.26882362 2.09072447 2.07539392   3.77434778 1.82215333 3.30864692 1.52123737 2.29328823 1.36722493   3.34969425 2.55285358 3.15811181 4.44630671 2.7549541 2.83824682   2.50485158 2.45610046 1.5423398 3.12460995 2.38987827 0.983325   2.64392757 3.11031628 1.41283321 2.58364391 2.17403984 3.19049454   2.83069992 1.04926252 2.93791962 2.37773943 3.51300693 3.02249169   2.59249544 1.81437802 3.34520817 3.04475498 4.02190208 3.84745455   2.45946741 2.06334138 2.11823249 2.95765638]
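The two snippets add the bias differently: Theano broadcasts it over the channel axis with `b.dimshuffle('x', 0, 'x', 'x')`, while `tf.nn.bias_add` adds it along the last axis of the NHWC tensor. The NumPy sketch below uses stand-in arrays (not the real convolution outputs) to check that these two routes agree once the TensorFlow result is moved back to NCHW, which suggests the bias step is not where the outputs diverge.

```python
import numpy as np

rng = np.random.RandomState(0)
conv_nchw = rng.rand(2, 16, 4, 4).astype(np.float32)  # stand-in for a conv output
bias = rng.uniform(0, 10, size=16).astype(np.float32)

# Theano route: bias[None, :, None, None] mirrors dimshuffle('x', 0, 'x', 'x'), then ReLU.
theano_style = np.maximum(conv_nchw + bias[None, :, None, None], 0)

# TensorFlow route: NHWC tensor, bias added along the last (channel) axis, then ReLU.
conv_nhwc = np.moveaxis(conv_nchw, 1, 3)
tf_style = np.maximum(conv_nhwc + bias, 0)

# Move back to NCHW, as the question does with np.moveaxis(out, 3, 1).
print(np.allclose(theano_style, np.moveaxis(tf_style, 3, 1)))  # True
```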

Why are the outputs different?
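A likely cause, offered as an explanation rather than an answer from the original thread: Theano's `conv2d` computes a true convolution by default (`filter_flip=True`), which rotates each kernel by 180 degrees, whereas `tf.nn.conv2d` computes a cross-correlation and does not flip the kernel. The NumPy sketch below uses two hypothetical reference functions, `conv2d_nchw_flipped` and `conv2d_nhwc_xcorr`, as stand-ins for the two libraries with 'valid' padding and stride 1. It reproduces the question's conversion and shows that flipping the spatial axes of the Theano filter before converting makes the results agree.

```python
import numpy as np

def conv2d_nchw_flipped(x, w):
    # Stand-in for Theano-style conv2d (valid padding, filter_flip=True):
    # x is (N, C, H, W), w is (O, C, KH, KW); the kernel is rotated 180 degrees.
    w = w[:, :, ::-1, ::-1]
    n, c, h, wd = x.shape
    o, _, kh, kw = w.shape
    out = np.zeros((n, o, h - kh + 1, wd - kw + 1), dtype=x.dtype)
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, :, i:i + kh, j:j + kw]  # (N, C, KH, KW)
            out[:, :, i, j] = np.einsum('nchw,ochw->no', patch, w)
    return out

def conv2d_nhwc_xcorr(x, w):
    # Stand-in for tf.nn.conv2d (padding='VALID', stride 1): cross-correlation.
    # x is (N, H, W, C), w is (KH, KW, C, O); the kernel is NOT flipped.
    n, h, wd, c = x.shape
    kh, kw, _, o = w.shape
    out = np.zeros((n, h - kh + 1, wd - kw + 1, o), dtype=x.dtype)
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # (N, KH, KW, C)
            out[:, i, j, :] = np.einsum('nhwc,hwco->no', patch, w)
    return out

def convert_filter(wts):  # (O, C, KH, KW) -> (KH, KW, C, O), as in the question
    return np.moveaxis(np.moveaxis(wts, 0, 3), 0, 2)

def convert_input(inp):  # (N, C, H, W) -> (N, H, W, C), as in the question
    return np.moveaxis(inp, 1, 3)

rng = np.random.RandomState(1)
x = rng.rand(2, 3, 8, 8)
w = rng.uniform(-0.5, 0.5, size=(4, 3, 3, 3))

ref = conv2d_nchw_flipped(x, w)

# The question's conversion: layouts match, but the kernel is not flipped.
naive = np.moveaxis(conv2d_nhwc_xcorr(convert_input(x), convert_filter(w)), 3, 1)
# Flip the spatial axes of the Theano filter before converting it.
fixed = np.moveaxis(
    conv2d_nhwc_xcorr(convert_input(x), convert_filter(w[:, :, ::-1, ::-1])), 3, 1)

print(np.allclose(ref, naive))  # expected False for random (asymmetric) kernels
print(np.allclose(ref, fixed))  # True
```

If this is indeed the cause, changing the TensorFlow code to `wt_tf = convert_filter(wt_th[:, :, ::-1, ::-1]).astype(np.float32)`, or building the Theano graph with `conv2d(input, W, filter_flip=False)`, should make the two outputs agree up to floating-point rounding.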
