tfjs用滑动窗口预测

时间:2019-01-02 10:36:16

标签: javascript tensorflow keras sliding-window tensorflow.js

您好,我已经用 Keras 创建并训练了一个检测系统,用于对仪表上的模拟数字进行检测和分类。该系统基于两个神经网络:第一个网络检测窗口中是否有数字,如果有,第二个网络告诉我它是什么数字。由于我知道数字在垂直方向上的位置,因此只需从左向右滑动窗口即可完成对数字的搜索。

因为我正在做Cordova混合应用程序,所以这些神经网络已移植到tensorflow.js。该解决方案出奇地好,但是非常慢。在我的电脑和Sony Xperia XZ1 Compact上,大约需要12秒。

// Side length, in pixels, of the square crop that is fed to both models.
const CROP_DIMENSION = 64


/**
 * Reads a sequence of digits from an image by sliding a window from left to
 * right and running two TensorFlow.js models on every window position:
 * `isNumberModel` decides whether the window contains a digit and
 * `numberClassifyModel` decides which one.
 *
 * Consecutive windows that contain a digit form one segment; each segment
 * contributes its most frequently predicted class to the result string.
 *
 * Relies on `tf`, `document`, `window` and `OCRNotInitializedError` being
 * available in the surrounding scope.
 */
export default class NumberDetector {

    // Window width, as a percentage of the image width.
    windowWidth = 7
    // Horizontal step between two windows, as a percentage of the image width.
    windowMoveBy = 1
    // Window height, as a percentage of the image height. The window starts at
    // (100 - windowHeight)% from the top, i.e. it is anchored to the bottom edge.
    windowHeight = 80
    isNumberModel = null
    numberClassifyModel = null
    isInitialized = false
    localeServerUrl = null

    /**
     * Scans `image` and returns the detected digits as a string.
     *
     * Side effects kept from the original implementation: every window that is
     * recognised as a digit has its crop canvas appended to `document.body`, and
     * per-window results plus the total duration (ms) are written to the console.
     *
     * @param {CanvasImageSource} image - source accepted by `drawImage`; its
     *     `width` and `height` are used to size the window.
     * @returns {Promise<string>} concatenated digits, '' when none were found.
     * @throws {OCRNotInitializedError} when `init()` has not completed.
     */
    async run (image) {
        if (!this.isInitialized) {
            throw new OCRNotInitializedError()
        }
        const startedAt = Date.now()

        // A zero/negative/NaN step yields no iterations instead of an endless loop.
        const rawStepCount = 100 / this.windowMoveBy
        const stepCount = Number.isFinite(rawStepCount) ? Math.floor(rawStepCount) : 0

        // The window geometry does not depend on the position, so compute it once.
        const stepPixels = this.byPercentage(image.width, this.windowMoveBy)
        const cropWidth = this.byPercentage(image.width, this.windowWidth)
        const cropTop = this.byPercentage(image.height, 100 - this.windowHeight)
        const cropHeight = this.byPercentage(image.height, this.windowHeight)

        let finalNumber = ''
        let currentSegmentNumbers = []
        for (let i = 0; i < stepCount; i++) {
            // A fresh canvas per window, because detected crops are attached to the page.
            const canvas = document.createElement('canvas')
            canvas.width = CROP_DIMENSION
            canvas.height = CROP_DIMENSION
            const context = canvas.getContext('2d')
            context.drawImage(image, i * stepPixels, cropTop, cropWidth, cropHeight, 0, 0, CROP_DIMENSION, CROP_DIMENSION)
            const rgba = context.getImageData(0, 0, CROP_DIMENSION, CROP_DIMENSION).data
            const grayImage = this.convertGrayImageArrayToImage(this.convertImageArrayToGray(rgba))

            // tidy() releases the intermediate tensor created from the nested array;
            // the batched input itself is released in the finally block below.
            const input = tf.tidy(() => tf.expandDims(grayImage, 0))
            try {
                if (await this.checkIfNumber(input)) {
                    window.document.body.appendChild(canvas)
                    const currentNumber = await this.classifyNumber(input)
                    console.log(currentNumber)
                    currentSegmentNumbers.push(currentNumber)
                } else {
                    console.log('x')
                    if (currentSegmentNumbers.length) {
                        finalNumber += this.getMostOccurrenceNumber(currentSegmentNumbers)
                    }
                    currentSegmentNumbers = []
                }
            } finally {
                input.dispose()
            }
        }
        // A digit that touches the right edge is never followed by an empty
        // window, so its segment has to be flushed here.
        if (currentSegmentNumbers.length) {
            finalNumber += this.getMostOccurrenceNumber(currentSegmentNumbers)
        }
        console.log('finalNumber', finalNumber, Date.now() - startedAt)
        return finalNumber
    }

    /**
     * Loads both models and marks the detector as ready.
     *
     * @param {?string} localeServerUrl - optional base URL prepended to the model paths.
     * @returns {Promise<void>} settles once both models are loaded.
     */
    async init (localeServerUrl) {
        this.localeServerUrl = localeServerUrl
        await Promise.all([
            this.loadIsNumberModel(),
            this.loadNumberClassifyModel()
        ])
        this.isInitialized = true
    }

    /**
     * Prefixes `path` with `localeServerUrl` when one is configured.
     *
     * @param {string} path - model path relative to the server root.
     * @returns {string} URL passed to the model loader.
     */
    resolveModelUrl (path) {
        return this.localeServerUrl ? `${this.localeServerUrl}/${path}` : path
    }

    /**
     * Loads the digit classifier into `numberClassifyModel`.
     * NOTE(review): `tf.loadModel` was renamed `tf.loadLayersModel` in
     * TensorFlow.js 1.0 - confirm against the bundled tfjs version.
     *
     * @returns {Promise<void>}
     */
    async loadNumberClassifyModel () {
        this.numberClassifyModel = await tf.loadModel(this.resolveModelUrl('models/numbers/model.json'))
    }

    /**
     * Loads the digit/no-digit model into `isNumberModel`.
     *
     * @returns {Promise<void>}
     */
    async loadIsNumberModel () {
        this.isNumberModel = await tf.loadModel(this.resolveModelUrl('models/yon/model.json'))
    }

    /**
     * Converts flat RGBA pixel data to one luminance value per pixel using the
     * 0.299 / 0.587 / 0.114 weights; the alpha channel is dropped.
     *
     * @param {ArrayLike<number>} imageArray - RGBA values, four entries per pixel.
     * @returns {number[]} gray values in the same 0-255 scale as the input.
     */
    convertImageArrayToGray (imageArray) {
        const grayScaleArray = []
        for (let i = 0; i < imageArray.length; i += 4) {
            grayScaleArray.push(
                (imageArray[i] * 0.299) + (imageArray[i + 1] * 0.587) + (imageArray[i + 2] * 0.114)
            )
        }
        return grayScaleArray
    }

    /**
     * Reshapes a flat, row-major gray image into a nested
     * [CROP_DIMENSION][CROP_DIMENSION][1] array (row, column, channel) and
     * scales every value by 1/255. Cells without a source value stay 0.
     *
     * @param {ArrayLike<number>} imageArray - at most CROP_DIMENSION² gray values.
     * @returns {number[][][]} nested image array.
     */
    convertGrayImageArrayToImage (imageArray) {
        const image = Array.from({ length: CROP_DIMENSION }, () =>
            Array.from({ length: CROP_DIMENSION }, () => [0]))
        for (let i = 0; i < imageArray.length; i++) {
            const row = Math.floor(i / CROP_DIMENSION)
            const column = i % CROP_DIMENSION
            image[row][column][0] = imageArray[i] / 255.0
        }
        return image
    }

    /**
     * Runs `model` on `input`, returns the index of the highest output value
     * and releases the output tensor afterwards.
     *
     * @param {object} model - object exposing `predict(input)`.
     * @param {object} input - batched input tensor; not disposed here.
     * @returns {Promise<number>} index of the largest output value.
     */
    async predictClassIndex (model, input) {
        const output = model.predict(input)
        try {
            return this.getIndexWithMaxValue(await output.data())
        } finally {
            output.dispose()
        }
    }

    /**
     * Class index predicted by `isNumberModel`. `run` treats any non-zero index
     * as "window contains a digit" (presumably class 0 means "no digit" -
     * confirm against the model's training labels).
     *
     * @param {object} croppedFloatArray - batched input tensor.
     * @returns {Promise<number>} predicted class index.
     */
    async checkIfNumber (croppedFloatArray) {
        return this.predictClassIndex(this.isNumberModel, croppedFloatArray)
    }

    /**
     * Class index predicted by `numberClassifyModel`.
     *
     * @param {object} croppedFloatArray - batched input tensor.
     * @returns {Promise<number>} predicted class index.
     */
    async classifyNumber (croppedFloatArray) {
        return this.predictClassIndex(this.numberClassifyModel, croppedFloatArray)
    }

    /**
     * Returns the value that occurs most often in `numbers`.
     *
     * Counts are kept on a plain object, so the winner comes back as that
     * object's key (a string) and ties are resolved by the object's key order.
     *
     * @param {Iterable<number|string>} numbers - predictions of one segment.
     * @returns {string|number} most frequent value as a string, or -1 when
     *     `numbers` is empty.
     */
    getMostOccurrenceNumber (numbers) {
        const occurrenceNumbers = {}
        for (const number of numbers) {
            occurrenceNumbers[number] = (occurrenceNumbers[number] || 0) + 1
        }
        let max = -1
        let maxKey = -1
        for (const key of Object.keys(occurrenceNumbers)) {
            if (occurrenceNumbers[key] > max) {
                max = occurrenceNumbers[key]
                maxKey = key
            }
        }
        return maxKey
    }

    /**
     * Returns `percentage` percent of `size`, truncated towards zero.
     *
     * @param {number} size
     * @param {number} percentage
     * @returns {number}
     */
    byPercentage (size, percentage) {
        return Math.trunc(size / 100 * percentage)
    }

    /**
     * Returns the numeric index of the largest value in `array`; the first
     * occurrence wins on ties and 0 is returned for an empty array.
     *
     * @param {ArrayLike<number>} array - plain or typed array of scores.
     * @returns {number} index of the maximum.
     */
    getIndexWithMaxValue (array) {
        let topVal = -Infinity
        let indexOfTopVal = 0
        for (let index = 0; index < array.length; index++) {
            if (array[index] > topVal) {
                topVal = array[index]
                indexOfTopVal = index
            }
        }
        return indexOfTopVal
    }
}

我认为问题在于特征图的重复计算。我需要以某种方式把整幅图像交给 TensorFlow,并告诉它:先计算特征图,然后在特征图上应用滑动窗口。当前的解决方案需要 12 秒才能生成输出。我真的不知道该怎么做。有人可以提供示例吗?

0 个答案:

没有答案