我训练了一个带有input_shape=[125,100,100,1]
的模型来预测8个花车。我在演示中更改了these options以符合我的模型设置。
然后我为批量大小添加了另一个选项
private static final int BATCH_SIZE = 125;
在the C++ side中,我打印了一些调试信息,以查看我的张量的形状:
LOG (INFO) << "input node: " << input_tensors[0].first << ", "
<< "input shape: " << input_tensors[0].second.shape().DebugString();
tensorflow_inference_jni.cc:198输入节点:input_node,输入形状: [125,100,100,1]
但应用程序在调用vars->session->Run()
函数
A/libc: Fatal signal 6 (SIGABRT), code -6 in tid 16574 (InferenceThread)
现在,如果我设置BATCH_SIZE = 1
(始终使用批量大小为125的模型进行处理),应用程序不会崩溃,但会返回此错误:
E/native: tensorflow_inference_jni.cc:213 Error during inference: Invalid argument: Input to reshape is a tensor with 8 values, but the requested shape has 1000
[[Node: output_node = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](fullyconnected2_1/BiasAdd, output_node/shape)]]
此错误中请求的形状1000是num_output * batch_size我猜(8 * 125)。
我错过了什么吗?答案 0 :(得分:0)
我已经对训练中的批量大小进行了硬编码,因此我必须使用形状为[125,100,100,1]
的张量来提供模型。这对于移动设备来说有点太多了,所以Android决定杀死该应用程序。
当我使用batch_size = 1