在android中部署tensorflow模型获取inputize不匹配错误

时间:2018-04-04 19:23:24

标签: android python tensorflow deep-learning conv-neural-network

我一直在关注以下tensorflow教程,在python中进行培训并在android中进行部署。

培训代码:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
# Just disables the warning, doesn't enable AVX/FMA
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
#matplotlib inline
plt.style.use('ggplot')


def read_data(file_path):
    column_names = ['user-id', 'activity', 'timestamp', 'x-axis', 'y-axis', 'z-axis']
    data = pd.read_csv(file_path, header=None, names=column_names)
    return data


def feature_normalize(dataset):
    mu = np.mean(dataset, axis=0)
    sigma = np.std(dataset, axis=0)
    return (dataset - mu) / sigma


def plot_axis(ax, x, y, title):
    ax.plot(x, y)
    ax.set_title(title)
    ax.xaxis.set_visible(False)
    ax.set_ylim([min(y) - np.std(y), max(y) + np.std(y)])
    ax.set_xlim([min(x), max(x)])
    ax.grid(True)


def plot_activity(activity, data):
    fig, (ax0, ax1, ax2) = plt.subplots(nrows=3, figsize=(15, 10), sharex=True)
    plot_axis(ax0, data['timestamp'], data['x-axis'], 'x-axis')
    plot_axis(ax1, data['timestamp'], data['y-axis'], 'y-axis')
    plot_axis(ax2, data['timestamp'], data['z-axis'], 'z-axis')
    plt.subplots_adjust(hspace=0.2)
    fig.suptitle(activity)
    plt.subplots_adjust(top=0.90)
    plt.show()


def windows(data, size):
    start = 0
    while start < data.count():
        yield int(start), int(start + size)
        start += (size / 2)


def segment_signal(data, window_size=90):
    segments = np.empty((0, window_size, 3))
    labels = np.empty((0))
    for (start, end) in windows(data['timestamp'], window_size):
        x = data["x-axis"][start:end]
        y = data["y-axis"][start:end]
        z = data["z-axis"][start:end]
        if (len(dataset['timestamp'][start:end]) == window_size):
            segments = np.vstack([segments, np.dstack([x, y, z])])
            labels = np.append(labels, stats.mode(data["activity"][start:end])[0][0])
    return segments, labels


def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.0, shape=shape)
    return tf.Variable(initial)


def depthwise_conv2d(x, W):
    return tf.nn.depthwise_conv2d(x, W, [1, 1, 1, 1], padding='VALID')


def apply_depthwise_conv(x, kernel_size, num_channels, depth):
    weights = weight_variable([1, kernel_size, num_channels, depth])
    biases = bias_variable([depth * num_channels])
    return tf.nn.relu(tf.add(depthwise_conv2d(x, weights), biases))


def apply_max_pool(x, kernel_size, stride_size):
    return tf.nn.max_pool(x, ksize=[1, 1, kernel_size, 1],
                          strides=[1, 1, stride_size, 1], padding='VALID')


dataset = read_data('WISDM_ar_v1.1_raw_copy.txt')

dataset.dropna(axis=0, how='any', inplace= True)
dataset.drop_duplicates(['user-id','activity','timestamp', 'x-axis', 'y-axis', 'z-axis'], keep= 'first', inplace= True)

dataset['x-axis'] = feature_normalize(dataset['x-axis'])
dataset['y-axis'] = feature_normalize(dataset['y-axis'])
dataset['z-axis'] = feature_normalize(dataset['z-axis'])


# for activity in np.unique(dataset["activity"]):
#     subset = dataset[dataset["activity"] == activity][:180]
#     plot_activity(activity,subset)

segments, labels = segment_signal(dataset)
labels = np.asarray(pd.get_dummies(labels), dtype=np.int8)
reshaped_segments = segments.reshape(len(segments), 1, 90, 3)

train_test_split = np.random.rand(len(reshaped_segments)) < 0.70
train_x = reshaped_segments[train_test_split]
train_y = labels[train_test_split]
test_x = reshaped_segments[~train_test_split]
test_y = labels[~train_test_split]

input_height = 1 # 1-Dimensional convulotion
input_width = 90 #window
num_labels = 6 #output labels
num_channels = 3 #input columns

batch_size = 10
kernel_size = 60
depth = 60
num_hidden = 1000

learning_rate = 0.0001
training_epochs = 1#8

total_batches = train_x.shape[0] # batch_size

X = tf.placeholder(tf.float32, shape=[None,input_height,input_width,num_channels])
Y = tf.placeholder(tf.float32, shape=[None,num_labels])

c = apply_depthwise_conv(X,kernel_size,num_channels,depth)
p = apply_max_pool(c,20,2)
c = apply_depthwise_conv(p,6,depth*num_channels,depth//10)

shape = c.get_shape().as_list()
c_flat = tf.reshape(c, [-1, shape[1] * shape[2] * shape[3]])

f_weights_l1 = weight_variable([shape[1] * shape[2] * depth * num_channels * (depth//10), num_hidden])
f_biases_l1 = bias_variable([num_hidden])
f = tf.nn.tanh(tf.add(tf.matmul(c_flat, f_weights_l1),f_biases_l1))

out_weights = weight_variable([num_hidden, num_labels])
out_biases = bias_variable([num_labels])
y_ = tf.nn.softmax(tf.matmul(f, out_weights) + out_biases,name="y_")

loss = -tf.reduce_sum(Y * tf.log(y_))
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate).minimize(loss)

correct_prediction = tf.equal(tf.argmax(y_,1), tf.argmax(Y,1)) #difference between correct output and expected output
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

cost_history = np.empty(shape=[1], dtype=float)


with tf.Session() as session:
    tf.global_variables_initializer().run()
    for epoch in range(training_epochs):
        for b in range(total_batches):
            offset = (b * batch_size) % (train_y.shape[0] - batch_size)
            batch_x = train_x[offset:(offset + batch_size), :, :, :]
            batch_y = train_y[offset:(offset + batch_size), :]
            _, c = session.run([optimizer, loss], feed_dict={X: batch_x, Y: batch_y})
            cost_history = np.append(cost_history, c)
        print "Epoch: ", epoch, " Training Loss: ", c, " Training Accuracy: ",
        session.run(accuracy, feed_dict={X: train_x, Y: train_y})
    print "Testing Accuracy:", session.run(accuracy, feed_dict={X: test_x, Y: test_y})

    saver = tf.train.Saver()
    tf.train.write_graph(session.graph_def, '.', './model/har.pbtxt')
    saver.save(session, save_path="./model/har.ckpt")

    # File_writer = tf.summary.FileWriter('./graph', session.graph)

冻结模型:

from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib
import tensorflow as tf

freeze_graph.freeze_graph(input_graph = "./model/har.pbtxt",  input_saver = "",
             input_binary = False, input_checkpoint = "./model/har.ckpt", output_node_names = "y_",
             restore_op_name = "save/restore_all", filename_tensor_name = "save/Const:0",
             output_graph = "./model/frozen_har.pb", clear_devices = True, initializer_nodes = "")

input_graph_def = tf.GraphDef()
with tf.gfile.Open("./model/frozen_har.pb", "r") as f:
    data = f.read()
    input_graph_def.ParseFromString(data)

output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        ["input"],
        ["y_"],
        tf.float32.as_datatype_enum)

f = tf.gfile.FastGFile("./model/optimized_har.pb", "w")
f.write(output_graph_def.SerializeToString())

当我尝试冻结模型时出现错误:

2018-04-05 10:00:50.604122: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
Converted 8 variables to const ops.
Traceback (most recent call last):
  File "/home/aawesh/PycharmProjects/Human-Activity-Recognition-using-CNN/Freeze.py", line 19, in <module>
    tf.float32.as_datatype_enum)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/optimize_for_inference_lib.py", line 109, in optimize_for_inference
    placeholder_type_enum)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/strip_unused_lib.py", line 83, in strip_unused
    raise KeyError("The following input nodes were not found: %s\n" % not_found)
KeyError: "The following input nodes were not found: set(['input'])\n"

Process finished with exit code 1

Android部分:

MainActivity.java

package io.github.aqibsaeed.activityrecognition;

import android.content.Context;
import android.hardware.Sensor;
import android.hardware.SensorEvent;
import android.hardware.SensorEventListener;
import android.hardware.SensorManager;
import android.os.Bundle;
import android.support.v7.app.AppCompatActivity;
import android.widget.TextView;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;

public class MainActivity extends AppCompatActivity implements SensorEventListener {

    private final int N_SAMPLES = 90;
    private static List<Float> x;
    private static List<Float> y;
    private static List<Float> z;
    private static List<Float> input_signal;
    private SensorManager mSensorManager;
    private Sensor mAccelerometer;
    private ActivityInference activityInference;

    private TextView downstairsTextView;
    private TextView joggingTextView;
    private TextView sittingTextView;
    private TextView standingTextView;
    private TextView upstairsTextView;
    private TextView walkingTextView;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        x = new ArrayList<Float>();
        y = new ArrayList<Float>();
        z = new ArrayList<Float>();
        input_signal = new ArrayList<Float>();

        downstairsTextView = (TextView)findViewById(R.id.downstairs_prob);
        joggingTextView = (TextView)findViewById(R.id.jogging_prob);
        sittingTextView = (TextView)findViewById(R.id.sitting_prob);
        standingTextView = (TextView)findViewById(R.id.standing_prob);
        upstairsTextView = (TextView)findViewById(R.id.upstairs_prob);
        walkingTextView = (TextView)findViewById(R.id.walking_prob);

        mSensorManager = (SensorManager) getSystemService(Context.SENSOR_SERVICE);
        mAccelerometer = mSensorManager.getDefaultSensor(Sensor.TYPE_ACCELEROMETER);
        mSensorManager.registerListener(this, mAccelerometer , SensorManager.SENSOR_DELAY_FASTEST);
        activityInference = new ActivityInference(getApplicationContext());
    }

    protected void onPause() {
        super.onPause();
        mSensorManager.unregisterListener(this);
    }

    protected void onResume() {
        super.onResume();
        mSensorManager.registerListener(this, mAccelerometer, SensorManager.SENSOR_DELAY_FASTEST);
    }

    @Override
    public void onSensorChanged(SensorEvent event) {
        activityPrediction();
        x.add(event.values[0]);
        y.add(event.values[1]);
        z.add(event.values[2]);
    }

    @Override
    public void onAccuracyChanged(Sensor sensor, int i) {

    }

    private void activityPrediction()
    {
        if(x.size() == N_SAMPLES && y.size() == N_SAMPLES && z.size() == N_SAMPLES) {
            // Mean normalize the signal
            normalize();

            // Copy all x,y and z values to one array of shape N_SAMPLES*3
            input_signal.addAll(x); input_signal.addAll(y); input_signal.addAll(z);

            // Perform inference using Tensorflow
            float[] results = activityInference.getActivityProb(toFloatArray(input_signal));

            downstairsTextView.setText(Float.toString(round(results[0],2)));
            joggingTextView.setText(Float.toString(round(results[1],2)));
            sittingTextView.setText(Float.toString(round(results[2],2)));
            standingTextView.setText(Float.toString(round(results[3],2)));
            upstairsTextView.setText(Float.toString(round(results[4],2)));
            walkingTextView.setText(Float.toString(round(results[5],2)));

            // Clear all the values
            x.clear(); y.clear(); z.clear(); input_signal.clear();
        }
    }

    private float[] toFloatArray(List<Float> list)
    {
        int i = 0;
        float[] array = new float[list.size()];

        for (Float f : list) {
            array[i++] = (f != null ? f : Float.NaN);
        }
        return array;
    }

    private void normalize()
    {
        float x_m = 0.662868f; float y_m = 7.255639f; float z_m = 0.411062f;
        float x_s = 6.849058f; float y_s = 6.746204f; float z_s = 4.754109f;

        for(int i = 0; i < N_SAMPLES; i++){
            x.set(i,((x.get(i) - x_m)/x_s));
            y.set(i,((y.get(i) - y_m)/y_s));
            z.set(i,((z.get(i) - z_m)/z_s));
        }
    }

    public static float round(float d, int decimalPlace) {
        BigDecimal bd = new BigDecimal(Float.toString(d));
        bd = bd.setScale(decimalPlace, BigDecimal.ROUND_HALF_UP);
        return bd.floatValue();
    }
}

ActivityInference.java

package io.github.aqibsaeed.activityrecognition;

import android.content.Context;
import android.content.res.AssetManager;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;


public class ActivityInference {
//    static {
//        System.loadLibrary("tensorflow_inference");
//    }

    private static ActivityInference activityInferenceInstance;
    private TensorFlowInferenceInterface inferenceInterface;
//    private static final String MODEL_FILE = "file:///android_asset/optimized_har.pb";
    private static final String MODEL_FILE = "file:///android_asset/optimized_har_1.pb";
    private static final String INPUT_NODE = "input";
    private static final String[] OUTPUT_NODES = {"y_"};
    private static final String OUTPUT_NODE = "y_";
    private static final long[] INPUT_SIZE = {1,270};
    private static final int OUTPUT_SIZE = 6;
    private static AssetManager assetManager;

    public static ActivityInference getInstance(final Context context)
    {
        if (activityInferenceInstance == null)
        {
            activityInferenceInstance = new ActivityInference(context);
        }
        return activityInferenceInstance;
    }

    public ActivityInference(final Context context) {
        this.assetManager = context.getAssets();
        inferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_FILE);
    }

    public float[] getActivityProb(float[] input_signal)
    {
        float[] result = new float[OUTPUT_SIZE];
        inferenceInterface.feed(INPUT_NODE,input_signal,INPUT_SIZE);
        inferenceInterface.run(OUTPUT_NODES);
        inferenceInterface.fetch(OUTPUT_NODE,result);
        //Downstairs    Jogging   Sitting   Standing    Upstairs    Walking
        return result;
    }
}

Implementation of CNN for Human Activity Recognition

Deploying the trained model in android

我的问题是:

  1. 在部署时,它使用名为&#39; input&#39;的节点。但是我们在训练时没有这样的张量名称。冻结图形时会出现问题。我该如何解决?

  2. 在训练模型时,它使用形状4维的输入张量。但是,在android中我们只提供一个列表。它是如何工作的?

  3. 请帮帮我。我被困在这一段很长一段时间了。

    Similar question has been asked here 但它不是重复的。我也没有在那里找到答案。

    我的尝试:我以某种方式设法释放模型并在android上运行它。我收到了以下错误。我仍然需要帮助来冻结模型:

    04-05 10:16:43.988 13148-13148/tictactoe.com.tictacoe E/AndroidRuntime: FATAL EXCEPTION: main
                                                                            Process: tictactoe.com.tictacoe, PID: 13148
                                                                            java.lang.IllegalArgumentException: input must be 4-dimensional[1,270]
                                                                                 [[Node: depthwise = DepthwiseConv2dNative[T=DT_FLOAT, data_format="NHWC", dilations=[1, 1, 1, 1], padding="VALID", strides=[1, 1, 1, 1], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_0_0, Variable)]]
                                                                                at org.tensorflow.Session.run(Native Method)
                                                                                at org.tensorflow.Session.access$100(Session.java:48)
                                                                                at org.tensorflow.Session$Runner.runHelper(Session.java:298)
                                                                                at org.tensorflow.Session$Runner.run(Session.java:248)
                                                                                at org.tensorflow.contrib.android.TensorFlowInferenceInterface.run(TensorFlowInferenceInterface.java:230)
                                                                                at org.tensorflow.contrib.android.TensorFlowInferenceInterface.run(TensorFlowInferenceInterface.java:197)
                                                                                at org.tensorflow.contrib.android.TensorFlowInferenceInterface.run(TensorFlowInferenceInterface.java:187)
                                                                                at io.github.aqibsaeed.activityrecognition.ActivityInference.getActivityProb(ActivityInference.java:43)
                                                                                at io.github.aqibsaeed.activityrecognition.MainActivity.activityPrediction(MainActivity.java:89)
                                                                                at io.github.aqibsaeed.activityrecognition.MainActivity.onSensorChanged(MainActivity.java:68)
                                                                                at android.hardware.SystemSensorManager$SensorEventQueue.dispatchSensorEvent(SystemSensorManager.java:503)
                                                                                at android.os.MessageQueue.nativePollOnce(Native Method)
                                                                                at android.os.MessageQueue.next(MessageQueue.java:143)
                                                                                at android.os.Looper.loop(Looper.java:130)
                                                                                at android.app.ActivityThread.main(ActivityThread.java:6895)
                                                                                at java.lang.reflect.Method.invoke(Native Method)
                                                                                at java.lang.reflect.Method.invoke(Method.java:372)
                                                                                at com.android.internal.os.ZygoteInit$MethodAndArgsCaller.run(ZygoteInit.java:1404)
                                                                                at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1199)
    

0 个答案:

没有答案