react-native tfjs：自定义模型（来自 Teachable Machine）无法正常工作

时间:2019-12-26 06:42:30

标签: react-native tensorflow

我尝试在 tfjs-react-native 中使用由 Teachable Machine 训练的自定义模型，但它无法正常工作。

加载 MobileNet 模型是成功的，得到的预测结果看起来也是正确的。

但是换成我的自定义模型后，输出结果却不符合预期。

我的模型摘要如下：

（此处原为模型摘要的截图，图片已缺失）

我的自定义模型的预测结果中不包含预测的类名（使用 MobileNet 模型时则会包含）。

下面是我的预测结果对象（其中没有预测的类名！）：

（此处原为预测结果对象的截图，图片已缺失）

这是我的代码：

import * as tf from '@tensorflow/tfjs';
import React, {Component} from 'react';
import * as mobilenet from '@tensorflow-models/mobilenet'

import {
    View,
    SafeAreaView,
    Text,
    Image,
        } from 'react-native';
import {bundleResourceIO, decodeJpeg, fetch} from '@tensorflow/tfjs-react-native';
import * as jpeg from 'jpeg-js'

export class TfjsSample extends React.Component {
    constructor(props) {
        super(props);
        this.state = {
            isTfReady: false,
        };
    }

    async bundleResourceIOExample() {
        const image = require('./assets/img/blog_8.jpg');

        const imageAssetPath = Image.resolveAssetSource(image);
        const response = await fetch(imageAssetPath.uri, {}, { isBinary: true });

        const rawImageData = await response.arrayBuffer();
        const imageTensor =  this.imageToTensor(rawImageData);
        const predictions = await this.model.predict(imageTensor);


        console.log('------------------------------------');
        console.log('lankType : ' , predictions.rankType );
        console.log('strides : ' , predictions.strides[0] );
        console.log('shape[0] : ' , predictions.shape[0] , ", shape[1] : "  , predictions.shape[1] );
        console.log('predictions[0] : ' + predictions[0].className, ", probability : " + predictions[0].probability);
        // console.log('predictions[1] : ' + predictions[1].className, ", probability : " + predictions[1].probability);
        // console.log('predictions[2] : ' + predictions[2].className, ", probability : " + predictions[2].probability);
    }

    imageToTensor(rawImageData) {
        const TO_UINT8ARRAY = true;
        const { width, height, data } = jpeg.decode(rawImageData, TO_UINT8ARRAY);

        const buffer = new Uint8Array(224 * 224 * 3);
        let offset = 0; // offset into original data
        for (let i = 0; i < buffer.length; i += 3) {
            buffer[i] = data[offset];
            buffer[i + 1] = data[offset + 1];
            buffer[i + 2] = data[offset + 2];

            offset += 4
        }


        return tf.tensor3d(buffer, [224, 224, 3]).expandDims(0);
    }


    async componentDidMount() {
        // Wait for tf to be ready.
        await tf.ready();
        this.setState({
            isTfReady: true,
        },()=>{

            // console.log("isTfReady!");
        });

        const URL = "https://teachablemachine.withgoogle.com/models/VC64D_5W/";
        const modelUrl = URL + "model.json";
        const metadataUrl = URL + "metadata.json";


        const modelJson = require('./assets/model.json');
        const modelWeights = require('./assets/weights.bin');
        const modelMetaData = require('./assets/metadata.json');
        this.model = await tf.loadLayersModel(bundleResourceIO(modelJson, modelWeights));


        this.model.summary();


        console.log("modelWeights: " + modelWeights);
        console.log("modelJson: "+modelJson.modelTopology.config.name);

        this.bundleResourceIOExample();
    }


    render() {
        return(

            <SafeAreaView>

                {!this.state.isTfReady &&
                <Text> TF is Not Ready!
                </Text>
                }

                {this.state.isTfReady &&
                <Text> TF is Ready!
                </Text>}


            </SafeAreaView>
        )
    }
}

0 个答案:

没有答案