Tensorflow - 将修改后的MNIST示例移植到Android中

时间:2017-10-25 09:51:24

标签: android python tensorflow

以下是我用来生成我即将使用的神经网络的CNN模型函数:

def cnn_model_fn(features, labels, mode):
  """Model function for CNN."""
  x = features["x"]
  input_layer = tf.reshape(x, [-1, 64, 64, 3])

  # The rest of the CNN goes here, utilizing input_layer as the input for the conv1 layer, similar to the MNIST handwritten digits example except with more conv-pool layers and eight classes

简单地说,我想在PC中使用Python训练这一点,并确保这与Android兼容,这意味着将正在进行的冻结graph.pb送入 TensorFlowInferenceInterface

在我使用的Android部分中:

private float[] getPixels(Mat img, double mean, double std){
    int inputSize = 64;
    int channels = 3;
    Bitmap bmp = Bitmap.createBitmap(inputSize, inputSize, Bitmap.Config.ARGB_8888);
    Utils.matToBitmap(img, bmp);
    int[] intValues = new int[inputSize * inputSize];
    bmp.getPixels(intValues, 0, inputSize, 0, 0, inputSize, inputSize);

    float[] floatValues = new float[inputSize * inputSize * channels];
    for (int i = 0; i < intValues.length; ++i) {
        final int val = intValues[i];
        floatValues[i * 3 + 0] = (float)(((double)(((val >> 16) & 0xFF)) - mean) / std);
        floatValues[i * 3 + 1] = (float)(((double)(((val >> 8) & 0xFF)) - mean) / std);
        floatValues[i * 3 + 2] = (float)(((double)((val & 0xFF)) - mean) / std);
    }

    return floatValues;
}

通过此调用从OpenCV Mat中检索平坦像素我想要输入TFII:

tfii.feed("enqueue_input/random_shuffle_queue", getPixels(colorResizedSubmat, mu.get(0, 0)[0], sigma.get(0, 0)[0]), 1, 64, 64, 3);

但我得到的只是一个例外,说:

java.lang.IllegalArgumentException: No OpKernel was registered to support Op 'RandomShuffleQueueV2' with these attrs.  Registered devices: [CPU], Registered kernels: <no registered kernels>

[[Node: enqueue_input/random_shuffle_queue = RandomShuffleQueueV2[capacity=1000, component_types=[DT_INT64, DT_INT32, DT_FLOAT], container="", min_after_dequeue=250, seed=1, seed2=4, shapes=[[], [], [64,64,3]], shared_name=""]()]]

我到处寻找这个问题,我得到的唯一线索是x tf.Placeholder()而不是字典,我不认为我可以把它改成占位符,因为我正在使用一个numpy_input_fn,我甚至不确定cnn_model_fn是否在tf.Session()中运行。

我该如何解决这个问题?

编辑:如果不使用Estimator API重写整个内容,是否有可能解决它?

0 个答案:

没有答案