TFLite output does not match TensorFlow output for conv2d_transpose

Time: 2019-02-04 18:58:52

Tags: python numpy tensorflow

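I'm trying to test whether tf.nn.conv2d_transpose can be used in TFLite. I can convert my model to TFLite without any errors, but the results I get differ from those of TensorFlow. For example: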

import tensorflow as tf
import numpy as np

np.random.seed(1234)
tf.random.set_random_seed(1234)

def trans_conv1d(x,
                 num_filters,
                 filter_length,
                 stride):
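    # Emulate a 1D transposed convolution with conv2d_transpose by inserting a height-1 axis
    batch_size, length, num_input_channels = x.get_shape().as_list()
    x = tf.reshape(x, [batch_size, 1, length, num_input_channels])

    # conv2d_transpose filters are laid out as [height, width, output_channels, in_channels]
    weights = tf.get_variable('W', shape=(1, filter_length, num_filters, num_input_channels))
    biases = tf.get_variable('b', shape=(num_filters,))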

    y = tf.nn.conv2d_transpose(
        x,
        filter=weights,
        output_shape=(batch_size, 1, stride * length, num_filters),
        strides=(1, 1, stride, 1),
        padding='SAME',
        data_format='NHWC',
        name="cnn2d")
    y = tf.nn.bias_add(y, biases)
    return y

num_filters = 4
filter_length = 40
stride = 8
x = tf.placeholder(dtype=tf.float32, shape=[1, 96, 2], name="input")
y = trans_conv1d(x, num_filters, filter_length, stride)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
input_data = np.array(np.random.rand(1, 96, 2), dtype=np.float32)
output_data_tf = sess.run(y, feed_dict={x:input_data})
converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [x], [y])
tflite_model = converter.convert()
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)
sess.close()

# tflite test
interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test model on the same input data.
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()
output_data_tflite = interpreter.get_tensor(output_details[0]['index'])
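# Exact equality is too strict for float32 results computed by two different kernels;
# compare with a tolerance and also report the largest absolute difference.
print(np.allclose(output_data_tf, output_data_tflite, atol=1e-5))
print(np.max(np.abs(output_data_tf - output_data_tflite)))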
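
To tell which of the two outputs is off, it can help to compute the same 1D transposed convolution a third way. The sketch below is a plain NumPy reference, assuming TensorFlow's usual SAME-padding rule for the strided convolution that conv2d_transpose transposes (total padding split with the smaller half on the left) and the [height, width, output_channels, in_channels] filter layout. The name trans_conv1d_reference is hypothetical; it is not part of TensorFlow or TFLite.

```python
import numpy as np


def trans_conv1d_reference(x, w, b, stride):
    """Hypothetical NumPy reference for the transposed conv in trans_conv1d.

    Assumed shapes (batch dimension and the height-1 axis dropped):
      x: (length, c_in)
      w: (filter_length, c_out, c_in), i.e. W[0] of the TF variable
      b: (c_out,)
    Returns an array of shape (stride * length, c_out).
    """
    length, c_in = x.shape
    filter_length, c_out, _ = w.shape
    out_len = stride * length

    # Assumed SAME padding of the forward strided convolution being transposed:
    # pad_total // 2 on the left, the remainder on the right.
    pad_total = max((length - 1) * stride + filter_length - out_len, 0)
    pad_left = pad_total // 2

    # Scatter-add every input position's contribution into a padded buffer,
    # then crop away the padding.
    full_len = max((length - 1) * stride + filter_length, pad_left + out_len)
    full = np.zeros((full_len, c_out), dtype=np.float64)
    for o in range(length):
        # full[o*stride + k, f] += sum_c w[k, f, c] * x[o, c]
        full[o * stride:o * stride + filter_length] += np.einsum("kfc,c->kf", w, x[o])
    return full[pad_left:pad_left + out_len] + b
```

Assuming the variables keep their default names, their values could be fetched before sess.close() with W_val, b_val = sess.run(['W:0', 'b:0']), and then ref = trans_conv1d_reference(input_data[0], W_val[0], b_val, stride) compared against both output_data_tf[0, 0] and output_data_tflite[0, 0] using np.max(np.abs(...)). Whichever output disagrees with ref is the more likely culprit.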

Any suggestions would be appreciated!

Thanks
