我正在测试 TFLite 是否支持 tf.nn.conv2d_transpose。模型可以无错误地转换为 TFLite，但推理结果与 TensorFlow 不一致。示例如下：
import tensorflow as tf
import numpy as np

np.random.seed(1234)
tf.random.set_random_seed(1234)


def trans_conv1d(x, num_filters, filter_length, stride):
    """Apply a 1-D transposed convolution that upsamples ``x`` by ``stride``.

    The 1-D op is emulated with ``tf.nn.conv2d_transpose`` by inserting a
    height-1 spatial axis in front of the length axis.

    Args:
        x: Rank-3 tensor read as ``(batch_size, length, num_input_channels)``.
            Its static shape must be fully defined, because the dimensions
            are used directly in ``tf.reshape`` and ``output_shape``.
        num_filters: Number of output channels.
        filter_length: Kernel width along the length axis.
        stride: Upsampling factor; the output length is ``stride * length``.

    Returns:
        Rank-4 tensor of shape ``(batch_size, 1, stride * length, num_filters)``
        (the inserted height-1 axis is kept) with bias variable ``b`` added.
    """
    batch_size, length, num_input_channels = x.get_shape().as_list()
    # Insert a height-1 axis so the 2-D op acts as a 1-D op along `length`.
    x = tf.reshape(x, [batch_size, 1, length, num_input_channels])
    # conv2d_transpose expects filters laid out as
    # (height, width, out_channels, in_channels).
    weights = tf.get_variable(
        'W', shape=(1, filter_length, num_filters, num_input_channels))
    biases = tf.get_variable('b', shape=(num_filters,))
    y = tf.nn.conv2d_transpose(
        x,
        filter=weights,
        output_shape=(batch_size, 1, stride * length, num_filters),
        strides=(1, 1, stride, 1),
        padding='SAME',
        data_format='NHWC',
        name="cnn2d")
    y = tf.nn.bias_add(y, biases)
    return y


num_filters = 4
filter_length = 40
stride = 8

x = tf.placeholder(dtype=tf.float32, shape=[1, 96, 2], name="input")
y = trans_conv1d(x, num_filters, filter_length, stride)

# `with` guarantees the session is closed even if conversion raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    input_data = np.array(np.random.rand(1, 96, 2), dtype=np.float32)
    output_data_tf = sess.run(y, feed_dict={x: input_data})
    converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [x], [y])
    tflite_model = converter.convert()

# Close the file handle deterministically instead of leaking it.
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)

# TFLite test.
interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the TFLite model on the same input data.
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data_tflite = interpreter.get_tensor(output_details[0]['index'])

# Bit-exact equality (np.array_equal) is the wrong test for float32 results
# from two different runtimes: TFLite executes its own transpose-conv kernel,
# whose accumulation order -- and therefore rounding -- may differ from TF's.
# Check shapes exactly, then values within a tolerance, and print the worst
# error so real kernel mismatches are distinguishable from rounding noise.
if output_data_tf.shape != output_data_tflite.shape:
    print("shape mismatch:", output_data_tf.shape, output_data_tflite.shape)
else:
    max_abs_diff = np.max(np.abs(output_data_tf - output_data_tflite))
    print("max abs diff:", max_abs_diff)
    print(np.allclose(output_data_tf, output_data_tflite, rtol=1e-5, atol=1e-5))
任何建议都将不胜感激！
谢谢