Integrating the im2txt model on an Android phone

Date: 2017-05-23 17:22:19

Tags: android tensorflow

I am new to TensorFlow and could not find solutions to the following problems.

  1. How can I retrain the im2txt model on my new dataset so that the training it already has is not lost and my new dataset is added on top of the im2txt dataset for captioning images with the new captions (i.e., training dataset = MSCOCO + my new dataset)? Could someone share the detailed procedure and the problems that may come up during retraining?

  2. I have found the TensorFlow tutorial for running the Inception V3 model on Android on live camera data. Does the same approach also work for the MSCOCO dataset model, i.e., can it be set up to caption images taken live from the phone camera? Could someone share the detailed steps on how to do this?

1 Answer:

Answer 0 (score: 3):

After struggling for weeks, I was able to run and execute the im2txt model on Android. Since I pieced the solution together from different blogs and different questions and answers, I felt it might be useful to have all (or most) of it in one place, so I am sharing the steps I followed.

You need to clone the TensorFlow project from https://github.com/tensorflow/tensorflow/releases/tag/v1.5.0 for freezing the graph and for the other tools.

Download the im2txt model from https://github.com/KranthiGV/Pretrained-Show-and-Tell-model. By following the steps described in that link, I was able to run inference and generate captions on a Linux desktop. It succeeded after renaming some variables in the graph (to overcome errors of the type NotFoundError (see above for traceback): Key lstm/basic_lstm_cell/bias not found in checkpoint).

Now we need to freeze the existing model to obtain a frozen graph that can be used on Android/iOS.

Using freeze_graph.py (tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py) from the cloned TensorFlow project, the graph of any model can be frozen. An example of command-line usage is:

bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax \
--input_binary=true

We need to provide all the output_node_names required to run the model. From "Pretrained-Show-and-Tell-model\im2txt\im2txt\inference_wrapper.py" we can list the output node names as 'softmax', 'lstm/initial_state' and 'lstm/state'. When I ran the freeze-graph command with the output node names 'softmax', 'lstm/initial_state' and 'lstm/state', I got the error "AssertionError: softmax is not in graph".

This is addressed in the answers by Steph and Jeff Tang to "How to freeze an im2txt model?":

The current model checkpoint files ckpt.data, ckpt.index and ckpt.meta, along with graph.pbtxt, should be loaded in inference mode (see InferenceWrapper in im2txt). That builds a graph with the correct names 'softmax', 'lstm/initial_state' and 'lstm/state'. You save this graph (with the same ckpt format), and then you can apply the freeze_graph script to obtain the frozen model.

To do this in Pretrained-Show-and-Tell-model\im2txt\im2txt\inference_utils\inference_wrapper.base.py, just add something like saver.save(sess, "model/ckpt4") after saver.restore(sess, checkpoint_path) in def _restore_fn(sess):. Then rebuild, run run_inference, and you will get a model that can be loaded by iOS and Android apps and can be frozen, transformed, and optionally memmapped.

Now I ran the command as below:

python tensorflow/python/tools/freeze_graph.py  \
--input_meta_graph=/tmp/ckpt4.meta \
--input_checkpoint=/tmp/ckpt4 \
--output_graph=/tmp/ckpt4_frozen.pb \
--output_node_names="softmax,lstm/initial_state,lstm/state" \
--input_binary=true

Loading the obtained ckpt4_frozen.pb file into the Android application gave the error:
"java.lang.IllegalArgumentException: No OpKernel was registered to support Op 'DecodeJpeg' with these attrs. Registered devices: [CPU], Registered kernels:
[[Node: decode/DecodeJpeg = DecodeJpeg[acceptable_fraction=1, channels=3, dct_method="", fancy_upscaling=true, ratio=1, try_recover_truncated=false]]"

From https://github.com/tensorflow/tensorflow/issues/2883:

Since DecodeJpeg isn't supported as part of the Android TensorFlow core, you need to strip it from the graph first:

bazel build tensorflow/python/tools:strip_unused && \
bazel-bin/tensorflow/python/tools/strip_unused \
--input_graph=ckpt4_frozen.pb \
--output_graph=ckpt4_frozen_stripped_graph.pb \
--input_node_names=convert_image/Cast,input_feed,lstm/state_feed \
--output_node_names=softmax,lstm/initial_state,lstm/state \
--input_binary=true

When I tried to load ckpt4_frozen_stripped_graph.pb on Android I again got errors, so I followed Jeff Tang's answer (Error using Model after using optimize_for_inference.py on frozen graph). Instead of tools:strip_unused, I used the graph transform tool:

bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=/tmp/ckpt4_frozen.pb \
--out_graph=/tmp/ckpt4_frozen_transformed.pb \
--inputs="convert_image/Cast,input_feed,lstm/state_feed" \
--outputs="softmax,lstm/initial_state,lstm/state" \
--transforms='
      strip_unused_nodes(type=float, shape="1,299,299,3")
      fold_constants(ignore_errors=true) 
      fold_batch_norms
      fold_old_batch_norms' 

I could successfully load the obtained ckpt4_frozen_transformed.pb on Android. When I fed floats of RGB image pixels to the input node "convert_image/Cast" and fetched the output from the "lstm/initial_state" node, it succeeded.
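
Since DecodeJpeg was stripped from the graph, the app itself has to decode and resize the image before feeding the pixels. Below is only a minimal sketch of that preprocessing using android.graphics.Bitmap (the helper name bitmapToRgbFloats is mine; whether the raw 0-255 values need further scaling depends on what convert_image/Cast expects in your graph):

// Sketch: resize a Bitmap to the model's input size and unpack its ARGB
// pixels into a float[inputSize * inputSize * 3] RGB array for convert_image/Cast.
private float[] bitmapToRgbFloats(Bitmap bitmap, int inputSize) {
    Bitmap scaled = Bitmap.createScaledBitmap(bitmap, inputSize, inputSize, true);
    int[] intValues = new int[inputSize * inputSize];
    scaled.getPixels(intValues, 0, inputSize, 0, 0, inputSize, inputSize);
    float[] floatValues = new float[inputSize * inputSize * 3];
    for (int i = 0; i < intValues.length; i++) {
        int val = intValues[i];
        floatValues[i * 3] = (val >> 16) & 0xFF;     // R
        floatValues[i * 3 + 1] = (val >> 8) & 0xFF;  // G
        floatValues[i * 3 + 2] = val & 0xFF;         // B
    }
    return floatValues;
}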

The challenge now is to understand the beam search in "Pretrained-Show-and-Tell-model\im2txt\im2txt\inference_utils\caption_generator.py"; the same has to be implemented on the Android side.

If you observe the Python script caption_generator.py:

softmax, new_states, metadata = self.model.inference_step(sess, input_feed, state_feed)

input_feed is an int32 array and state_feed is a multidimensional float array.

On the Android side, I tried providing an int32 array for "input_feed". Since there is no Java API to feed a multidimensional array, I provided a plain float array to lstm/state_feed, as it was fetched earlier from the "lstm/initial_state" node.

I got two errors: one that input_feed expects int64, and "java.lang.IllegalArgumentException: -input rank(-1) <= split_dim < input rank (1), but got 1" at lstm/state_feed.

For the first error, I changed the input_feed data type from int32 to int64.

Regarding the second error, it expects a rank-2 tensor. If you look at the TensorFlow Java source code, the float array we provide is converted to a rank-1 tensor; we would have to provide the data in a way that creates a rank-2 tensor, but at the time I did not find any Android API for feeding a multidimensional float array. While browsing the TensorFlow Java source code, I found APIs that are not exposed as Android APIs with which a rank-2 tensor can be created. So I rebuilt libtensorflow_inference.so and libandroid_tensorflow_inference_java.jar with the rank-2 tensor creation calls enabled. (For the build process, refer to https://blog.mindorks.com/android-tensorflow-machine-learning-example-ff0e9b2654cc.)
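
For illustration, here is a minimal sketch of building a rank-2 tensor from a float[][] once the full Java API is available (as in the rebuilt libraries). It uses java.nio.FloatBuffer and org.tensorflow.Tensor; the variable names are mine, and this is a sketch under those assumptions, not the exact code I shipped:

// Sketch: flatten a [batch, stateSize] float[][] row by row and wrap it in a
// rank-2 tensor via Tensor.create(long[] shape, FloatBuffer data).
float[][] stateFeed = new float[batch][stateSize]; // e.g. rows of the previous lstm/state
FloatBuffer stateBuf = FloatBuffer.allocate(batch * stateSize);
for (float[] row : stateFeed) {
    stateBuf.put(row);
}
stateBuf.flip(); // rewind so Tensor.create reads all batch * stateSize floats
Tensor<Float> stateTensor = Tensor.create(new long[]{batch, stateSize}, stateBuf);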

Now I can run inference on Android and get one caption for an image, but the accuracy is very low. The reason for being limited to one caption is that I did not find a way to fetch the output as a multidimensional array, which is required to generate more than one caption for a single image.
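
If such a fetch becomes possible (for example with the rebuilt libraries above), one way to recover per-caption rows is to treat the fetched buffer as one row-major [batch, numClasses] block. A minimal sketch, assuming a buffer (outputBuf, a name of mine) allocated as batch * numClasses and filled by fetch():

// Sketch: split a row-major [batch, numClasses] FloatBuffer into one
// probability row per partial caption.
outputBuf.rewind(); // fetch() leaves the position at the end of the buffer
float[][] rows = new float[batch][numClasses];
for (int r = 0; r < batch; r++) {
    outputBuf.get(rows[r]); // reads numClasses floats for row r
}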

String actualFilename = labelFilename.split("file:///android_asset/")[1];

vocab = new Vocabulary(assetManager.open(actualFilename));


inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
final Graph g = inferenceInterface.graph();

final Operation inputOperation = g.operation(inputName);
if (inputOperation == null) {
    throw new RuntimeException("Failed to find input Node '" + inputName + "'");
}
final Operation outPutOperation = g.operation(outputName);

if (outPutOperation == null) {
    throw new RuntimeException("Failed to find output Node '" + outputName + "'");
}

// The shape of the output is [N, NUM_CLASSES], where N is the batch size.
int numClasses = (int) inferenceInterface.graph().operation(outputName)
        .output(0).shape().size(1);


Log.i(TAG, "Read " + vocab.totalWords() + " labels, output layer size is " + numClasses);

// Ideally, inputSize could have been retrieved from the shape of the input operation.  Alas,
// the placeholder node for input in the graphdef typically used does not specify a shape, so it
// must be passed in as a parameter.
this.inputSize = inputSize;

// Pre-allocate buffers.
outputNames = new String[]{outputName + ":0"};
outputs = new float[numClasses];
inferenceInterface.feed(inputName + ":0", pixels, inputSize, inputSize, 3);


inferenceInterface.run(outputNames, runStats);


inferenceInterface.fetch(outputName + ":0", outputs);



startIm2txtBeamSearch(outputs);

// Beam search implemented in Java

private void startIm2txtBeamSearch(float[] outputs) {

        int beam_size = 1;
        //TODO:Prepare vocab ids from file
        ArrayList<Integer> vocab_ids = new ArrayList<>();
        vocab_ids.add(1);
        int vocab_end_id = 2;
        float length_normalization_factor = 0;
        int maxCaptionLength = 20;
        Graph g = inferenceInterface.graph();


        //node input feed
        String input_feed_node_name = "input_feed";
        Operation inputOperation = g.operation(input_feed_node_name);
        if (inputOperation == null) {
            throw new RuntimeException("Failed to find input Node '" + input_feed_node_name + "'");
        }

        String output_feed_node_name = "softmax";
        Operation outPutOperation = g.operation(output_feed_node_name);
        if (outPutOperation == null) {
            throw new RuntimeException("Failed to find output Node '" + output_feed_node_name + "'");
        }
        int output_feed_node_numClasses = (int) outPutOperation.output(0).shape().size(1);
        Log.i(TAG, "Output layer " + output_feed_node_name + ", output layer size is " + output_feed_node_numClasses);
        FloatBuffer output_feed_output = FloatBuffer.allocate(output_feed_node_numClasses);
        //float [][] output_feed_output = new float[numClasses][];

        //node state feed
        String input_state_feed_node_name = "lstm/state_feed";
        inputOperation = g.operation(input_state_feed_node_name);
        if (inputOperation == null) {
            throw new RuntimeException("Failed to find input Node '" + input_state_feed_node_name + "'");
        }
        String output_state_feed_node_name = "lstm/state";
        outPutOperation = g.operation(output_state_feed_node_name);
        if (outPutOperation == null) {
            throw new RuntimeException("Failed to find output Node '" + output_state_feed_node_name + "'");
        }
        int output_state_feed_node_numClasses = (int) outPutOperation.output(0).shape().size(1);
        Log.i(TAG, "Output layer " + output_state_feed_node_name + ", output layer size is " + output_state_feed_node_numClasses);
        FloatBuffer output_state_output = FloatBuffer.allocate(output_state_feed_node_numClasses);
        //float[][] output_state_output= new float[numClasses][];
        String[] output_nodes = new String[]{output_feed_node_name, output_state_feed_node_name};


        Caption initialBeam = new Caption(vocab_ids, outputs, (float) 0.0, (float) 0.0);
        TopN partialCaptions = new TopN(beam_size);
        partialCaptions.push(initialBeam);
        TopN completeCaption = new TopN(beam_size);


        captionLengthLoop:
        for (int i = maxCaptionLength; i >= 0; i--) {
            List<Caption> partialCaptionsList = new LinkedList<>(partialCaptions.extract(false));
            partialCaptions.reset();

            long[] input_feed = new long[partialCaptionsList.size()];
            float[][] state_feed = new float[partialCaptionsList.size()][];

            for (int j = 0; j < partialCaptionsList.size(); j++) {
                Caption curCaption = partialCaptionsList.get(j);
                ArrayList<Integer> senArray = curCaption.getSentence();
                input_feed[j] = senArray.get(senArray.size() - 1);
                state_feed[j] = curCaption.getState();
            }
            // Feeding. Note: feeding a float[][] here relies on the rank-2 tensor
            // support enabled in the rebuilt TensorFlow libraries (see above).
            inferenceInterface.feed(input_feed_node_name, input_feed, new long[]{input_feed.length});

            inferenceInterface.feed(input_state_feed_node_name, state_feed, new long[]{state_feed.length});


            //run
            inferenceInterface.run(output_nodes, runStats);

            //fetching
            inferenceInterface.fetch(output_feed_node_name, output_feed_output);
            inferenceInterface.fetch(output_state_feed_node_name, output_state_output);

            for (int k = 0; k < partialCaptionsList.size(); k++) {
                // With beam_size = 1 the fetched buffers hold exactly one row each,
                // so the whole backing array is treated as this caption's row.
                float[] word_probabilities = output_feed_output.array();
                float[] new_state = output_state_output.array();

                // For this partial caption, get the beam_size most probable next words.
                Map<Integer, Float> word_and_probs = new LinkedHashMap<>();
                // key: word id (index into the vocab); value: its probability
                for (int l = 0; l < word_probabilities.length; l++) {
                    word_and_probs.put(l, word_probabilities[l]);
                }
                //sorting
//                word_and_probs = word_and_probs.entrySet().stream()
//                        .sorted(Map.Entry.comparingByValue())
//                        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,(e1, e2) -> e1, LinkedHashMap::new));
                word_and_probs = MapUtil.sortByValue(word_and_probs);
                // keep only the first beam_size probabilities
                LinkedHashMap<Integer, Float> final_word_and_probs = new LinkedHashMap<>();

                for (int key : word_and_probs.keySet()) {
                    final_word_and_probs.put(key, word_and_probs.get(key));
                    if (final_word_and_probs.size() == beam_size)
                        break;
                }

                for (int w : final_word_and_probs.keySet()) {
                    float p = final_word_and_probs.get(w);
                    if (p < 1e-12) { // avoid log(0)
                        Log.d(TAG, "p is < 1e-12");
                        continue;
                    }
                    Caption partialCaption = partialCaptionsList.get(k);
                    ArrayList<Integer> sentence = new ArrayList<>(partialCaption.getSentence());
                    sentence.add(w);
                    float logprob = (float) (partialCaption.getProb() + Math.log(p));
                    float score = logprob;
                    if (w == vocab_end_id) {
                        // As in the Python caption_generator: optionally normalize the
                        // score of a complete caption by its length.
                        if (length_normalization_factor > 0)
                            score = (float) (logprob / Math.pow(sentence.size(), length_normalization_factor));
                        completeCaption.push(new Caption(sentence, new_state, logprob, score));
                    } else {
                        partialCaptions.push(new Caption(sentence, new_state, logprob, score));
                    }
                }
                if (partialCaptions.getSize() == 0) // ran out of partial candidates; happens when beam_size = 1
                    break captionLengthLoop;
            }


            // Reallocate buffers to retrieve the subsequent outputs.
            output_feed_output = FloatBuffer.allocate(output_feed_node_numClasses);
            output_state_output = FloatBuffer.allocate(output_state_feed_node_numClasses);
            Log.d(TAG, "----" + i + " Iteration completed----");
        }
        Log.d(TAG, "----Total Iterations completed----");
        LinkedList<Caption> completeCaptions = completeCaption.extract(true);

        for (Caption cap : completeCaptions) {

            ArrayList<Integer> wordids = cap.getSentence();
            StringBuffer caption = new StringBuffer();
            boolean isFirst = true;
            for (int word : wordids) {
                if (!isFirst)
                    caption.append(" ");
                caption.append(vocab.getWord(word));
                isFirst = false;
            }
            Log.d(TAG, "Cap score = " + Math.exp(cap.getScore()) + " and Caption is " + caption);
        }

    }

// Vocabulary

    public class Vocabulary {
        String TAG = Vocabulary.class.getSimpleName();
        String start_word = "<S>", end_word = "</S>", unk_word = "<UNK>";
        ArrayList<String> words;

        public Vocabulary(File vocab_file) {
            loadVocabsFromFile(vocab_file);
        }

        public Vocabulary(InputStream vocab_file_stream) {
            words = readLinesFromFileAndLoadWords(new InputStreamReader(vocab_file_stream));
        }

        public Vocabulary(String vocab_file_path) {
            File vocabFile = new File(vocab_file_path);
            loadVocabsFromFile(vocabFile);
        }

        private void loadVocabsFromFile(File vocabFile) {
            try {
                this.words = readLinesFromFileAndLoadWords(new FileReader(vocabFile));
                //Log.d(TAG, "Words read from file = " + words.size());
            } catch (FileNotFoundException e) {
                e.printStackTrace();
            }
        }


        private ArrayList<String> readLinesFromFileAndLoadWords(InputStreamReader file_reader) {
            ArrayList<String> words = new ArrayList<>();
            try (BufferedReader br = new BufferedReader(file_reader)) {
                String line;
                while ((line = br.readLine()) != null) {
                    // process the line.
                    words.add(line.split(" ")[0].trim());
                }
                if (!words.contains(unk_word))
                    words.add(unk_word);
            } catch (IOException e) {
                e.printStackTrace();
            }

            return words;
        }

        public String getWord(int word_id) {
            if (words != null)
                if (word_id >= 0 && word_id < words.size())
                    return words.get(word_id);
            return "No word found, Maybe Vocab File not loaded";
        }

        public int totalWords() {
            if (words != null)
                return words.size();
            return 0;
        }
    }

// MapUtil

public class MapUtil {


    public static <K, V extends Comparable<? super V>> Map<K, V> sortByValue(Map<K, V> map) {
        List<Map.Entry<K, V>> list = new ArrayList<>(map.entrySet());
        list.sort(new Comparator<Map.Entry<K, V>>() {
            @Override
            public int compare(Map.Entry<K, V> o1, Map.Entry<K, V> o2) {
                if (o1.getValue() instanceof Float && o2.getValue() instanceof Float) {
                    Float o1Float = (Float) o1.getValue();
                    Float o2Float = (Float) o2.getValue();
                    // Descending order; compareTo also satisfies the comparator
                    // contract by returning 0 for equal values.
                    return o2Float.compareTo(o1Float);
                }
                return 0;
            }
        });

        Map<K, V> result = new LinkedHashMap<>();
        for (Map.Entry<K, V> entry : list) {
            result.put(entry.getKey(), entry.getValue());
        }

        return result;
    }

}

// Caption

    public class Caption implements Comparable<Caption> {

        private ArrayList<Integer> sentence;
        private float[] state;
        private float prob;
        private float score;

        public Caption(ArrayList<Integer> sentence, float[] state, float prob, float score) {
            this.sentence = sentence;
            this.state = state;
            this.prob = prob;
            this.score = score;
        }

        public ArrayList<Integer> getSentence() {
            return sentence;
        }

        public void setSentence(ArrayList<Integer> sentence) {
            this.sentence = sentence;
        }

        public float[] getState() {
            return state;
        }

        public void setState(float[] state) {
            this.state = state;
        }

        public float getProb() {
            return prob;
        }

        public void setProb(float prob) {
            this.prob = prob;
        }

        public float getScore() {
            return score;
        }

        public void setScore(float score) {
            this.score = score;
        }

        @Override
        public int compareTo(@NonNull Caption oc) {
            return Float.compare(score, oc.score);
        }
    }

// TopN

 public class TopN {

    //Maintains the top n elements of an incrementally provided set.
    int n;
    LinkedList<Caption> data;


    public TopN(int n) {
        this.n = n;
        this.data = new LinkedList<>();
    }

    public int getSize() {
        if (data != null)
            return data.size();
        return 0;
    }

    // Pushes a new element, keeping only the best n by score.
    public void push(Caption x) {
        if (data != null) {
            if (getSize() < n) {
                data.add(x);
            } else {
                Collections.sort(data); // ascending: worst caption first
                if (x.compareTo(data.getFirst()) > 0) {
                    data.removeFirst(); // evict the current worst, not just the last
                    data.add(x);
                }
            }
        }
    }

    // Extracts all elements from the TopN. This is a destructive operation.
    // The only method that can be called immediately after extract() is reset().
    // Args:
    //   sort: Whether to return the elements in descending sorted order.
    // Returns: A list of data; the top n elements provided to the set.

    public LinkedList<Caption> extract(boolean sort) {
        if (sort) {
            Collections.sort(data, Collections.reverseOrder()); // best caption first
        }
        return data;
    }

    //Returns the TopN to an empty state.
    public void reset() {
        if (data != null) data.clear();
    }

}

Even though the accuracy is very low, I am sharing this because it may be useful to someone trying to load the Show and Tell model on Android.