Tensor (.tflite) model inference on Swift using the Firebase SDK returns nil

Date: 2019-12-02 14:34:41

Tags: swift tensorflow neural-network tensorflow-lite firebase-mlkit

Preface:

My knowledge of ML (and specifically NNs) is very limited, though I'm getting more and more familiar with it over time.

Essentially, I have a model that takes an input of [1, H, W, 3] (1 image, height, width, 3 channels) and should output [1, H, W, 2] (1 image, height, width, 2 channels). The idea is that I should be able to take the image data from one of the output channels and turn it back into an actual image, and if some kind of "thing" is present in the input image, that output image should show it highlighted using that one colour channel (or the other one).
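
To make that intent concrete, here is a minimal sketch (my own illustration, not code from the project) of how one channel of a flattened [1, 512, 512, 2] float output could be rendered as a grayscale mask image. The [Float] input, the channel-last layout, and the assumption that values fall in [0.0, 1.0] are all assumptions on my part:

import UIKit

// Hypothetical helper: renders one channel of a flat [1, H, W, 2] float buffer
// (values assumed to be in [0.0, 1.0], channel-last layout) as a grayscale UIImage.
func maskImage(from output: [Float], height: Int, width: Int, channel: Int) -> UIImage? {
    var pixels = [UInt8](repeating: 0, count: width * height)
    for row in 0 ..< height {
        for col in 0 ..< width {
            // Index into the flat buffer: (row * width + col) * channels + channel
            let value = output[(row * width + col) * 2 + channel]
            pixels[row * width + col] = UInt8(max(0, min(255, value * 255)))
        }
    }
    guard let provider = CGDataProvider(data: Data(pixels) as CFData),
          let cgImage = CGImage(width: width, height: height,
                                bitsPerComponent: 8, bitsPerPixel: 8,
                                bytesPerRow: width,
                                space: CGColorSpaceCreateDeviceGray(),
                                bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue),
                                provider: provider, decode: nil,
                                shouldInterpolate: false,
                                intent: .defaultIntent) else { return nil }
    return UIImage(cgImage: cgImage)
}

If the model actually emits logits rather than probabilities, the values would need a sigmoid/softmax first; that depends on how the author built the model.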

The model's author is still actively working on it, so it's far from a finished or perfect model.

So,

I originally used the TensorFlowLite SDK for everything, but I found that the official documentation, examples, and open-source work around it don't even compare to what exists for the Firebase SDK. On top of that, the actual project (currently running in a test environment) already uses the Firebase SDK. In any case, I was able to get some form of output with it, but because I wasn't normalizing the images properly the output wasn't what was expected; at least there was something, though.

I followed this guide on Firebase to try to run inference on the tflite model.

From the code below you'll see that I have TensorFlowLite as a dependency, but I'm not actively using it. There is one function that uses it, but that function is never called. So you can basically ignore: parseOutputTensor, coordinateToIndex, and the Constants enum.

Theories:

  1. I'm not setting up my model input correctly (a byte-count sanity check for this is sketched right after this list).
  2. I'm not reading the output correctly.
  3. I'm not resizing and processing the image correctly before using it as the input data for inference.
  4. I have no idea what I'm doing and I'm completely off. D:
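
For theory 1, one cheap sanity check (my own sketch, not code from the project) is to confirm that the Data buffer handed to ModelInputs is exactly batch × height × width × channels × 4 bytes, since the input format is declared as float32 [1, 512, 512, 3]:

import Foundation

// Hypothetical check: a float32 tensor of shape [batch, height, width, channels]
// must occupy batch * height * width * channels * 4 bytes.
func inputSizeLooksRight(_ inputData: Data,
                         batch: Int = 1, height: Int = 512,
                         width: Int = 512, channels: Int = 3) -> Bool {
    let expected = batch * height * width * channels * MemoryLayout<Float32>.size
    if inputData.count != expected {
        print("Input is \(inputData.count) bytes, expected \(expected) bytes")
        return false
    }
    return true
}

For [1, 512, 512, 3] that works out to 3,145,728 bytes; anything else usually points at a wrong image size or a wrong channel/element count.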

Below is my code:

import UIKit
import Firebase
import AVFoundation
import TensorFlowLite

class ViewController: UIViewController {
    var captureSesssion : AVCaptureSession!
    var cameraOutput : AVCapturePhotoOutput!
    var previewLayer : AVCaptureVideoPreviewLayer!
    @objc let device = AVCaptureDevice.default(for: .video)!
    private var previousInferenceTimeMs: TimeInterval = Date.distantPast.timeIntervalSince1970 * 1000
    private let delayBetweenInferencesMs: Double = 1000
    @IBOutlet var imageView: UIImageView!

    private var button1 : UIButton = {
        var button = UIButton()
        button.setTitle("button lol", for: .normal)
        button.translatesAutoresizingMaskIntoConstraints = false
        button.addTarget(self, action: #selector(buttonClicked), for: .touchDown)
        return button
    }()

    override func viewDidLoad() {
        super.viewDidLoad()
        startCamera()
        view.addSubview(button1)
        view.bringSubviewToFront(button1)
        button1.bottomAnchor.constraint(equalTo: view.bottomAnchor).isActive = true
        button1.titleLabel?.font = UIFont(name: "Helvetica", size: 25)
        button1.widthAnchor.constraint(equalToConstant: view.frame.width/3).isActive = true
        button1.centerXAnchor.constraint(equalTo: view.centerXAnchor).isActive = true
    }

    @objc func buttonClicked() {
        cameraPressed()
    }

    private func configureLocalModel() -> CustomLocalModel {
        guard let modelPath = Bundle.main.path(forResource: "modelName", ofType: "tflite") else { fatalError("Couldn't find the modelPath") }
        return CustomLocalModel(modelPath: modelPath)
    }

    private func createInterpreter(customLocalModel: CustomLocalModel) -> ModelInterpreter{
        return ModelInterpreter.modelInterpreter(localModel: customLocalModel)
    }

    private func setModelInputOutput() -> ModelInputOutputOptions? {
        var ioOptions : ModelInputOutputOptions
        do {
            ioOptions = ModelInputOutputOptions()
            try ioOptions.setInputFormat(index: 0, type: .float32, dimensions: [1, 512, 512, 3])
            try ioOptions.setOutputFormat(index: 0, type: .float32, dimensions: [1, 512, 512, 2])
        } catch let error as NSError {
            print("Failed to set input or output format with error: \(error.localizedDescription)")
            return nil
        }
        return ioOptions
    }

    private func inputDataForInference(theImage: CGImage) -> ModelInputs?{
        let image: CGImage = theImage
        guard let context = CGContext(
            data: nil,
            width: image.width, height: image.height,
            bitsPerComponent: 8, bytesPerRow: image.width * 4,
            space: CGColorSpaceCreateDeviceRGB(),
            bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
        ) else { fatalError("Context issues") }

        context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
        guard let imageData = context.data else { fatalError("Context issues") }

        let inputs : ModelInputs
        var inputData = Data()
        do {
            for row in 0 ..< 512 {
                for col in 0 ..< 512 {
                    let offset = 4 * (col * context.width + row)
                    // (Ignore offset 0, the unused alpha channel)
                    let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
                    let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
                    let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)

                    // Normalize channel values to [0.0, 1.0]. This requirement varies
                    // by model. For example, some models might require values to be
                    // normalized to the range [-1.0, 1.0] instead, and others might
                    // require fixed-point values or the original bytes.
                    var normalizedRed = Float32(red) / 255.0
                    var normalizedGreen = Float32(green) / 255.0
                    var normalizedBlue = Float32(blue) / 255.0

                    // Append normalized values to Data object in RGB order.
                    let elementSize = MemoryLayout.size(ofValue: normalizedRed)
                    var bytes = [UInt8](repeating: 0, count: elementSize)
                    memcpy(&bytes, &normalizedRed, elementSize)
                    inputData.append(&bytes, count: elementSize)
                    memcpy(&bytes, &normalizedGreen, elementSize)
                    inputData.append(&bytes, count: elementSize)
                    memcpy(&bytes, &normalizedBlue, elementSize)
                    inputData.append(&bytes, count: elementSize)
                }
            }
            inputs = ModelInputs()
            try inputs.addInput(inputData)
        } catch let error {
            print("Failed to add input: \(error)")
            return nil
        }
        return inputs
    }

    private func runInterpreter(interpreter: ModelInterpreter, inputs: ModelInputs, ioOptions: ModelInputOutputOptions){
        interpreter.run(inputs: inputs, options: ioOptions) { outputs, error in
            guard error == nil, let outputs = outputs else { fatalError("Interpreter run failed: error is not nil or outputs is nil") }
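            // The output format was declared as float32 [1, 512, 512, 2] in setModelInputOutput();
            // here the result is cast to a two-level [[NSNumber]] array before printing.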
            let output = try? outputs.output(index: 0) as? [[NSNumber]]
            print()
            print("output?[0]: \(output?[0])")
            print("output?.count: \(output?.count)")
            print("output?.description: \(output?.description)")
        }
    }


    private func gotImage(cgImage: CGImage){
        let configuredModel = configureLocalModel()
        let interpreter = createInterpreter(customLocalModel: configuredModel)
        guard let modelioOptions = setModelInputOutput() else { fatalError("modelioOptions got image error") }
        guard let modelInputs = inputDataForInference(theImage: cgImage) else { fatalError("modelInputs got image error") }
        runInterpreter(interpreter: interpreter, inputs: modelInputs, ioOptions: modelioOptions)
    }

    private func resizeImage(image: UIImage, targetSize: CGSize) -> UIImage {
        let newSize = CGSize(width: targetSize.width, height: targetSize.height)

        // This is the rect that we've calculated out and this is what is actually used below
        let rect = CGRect(x: 0, y: 0, width: targetSize.width, height: targetSize.height)

        // Actually do the resizing to the rect using the ImageContext stuff
        UIGraphicsBeginImageContextWithOptions(newSize, false, 1.0)
        image.draw(in: rect)
        let newImage = UIGraphicsGetImageFromCurrentImageContext()
        UIGraphicsEndImageContext()

        return newImage!
    }

}

extension ViewController: AVCapturePhotoCaptureDelegate{
        func startCamera(){
            captureSesssion = AVCaptureSession()
            previewLayer = AVCaptureVideoPreviewLayer(session: captureSesssion)
            captureSesssion.sessionPreset = AVCaptureSession.Preset.photo;
            cameraOutput = AVCapturePhotoOutput()
            previewLayer.frame = CGRect(x: view.frame.origin.x, y: view.frame.origin.y, width: view.frame.width, height: view.frame.height)
            previewLayer.videoGravity = AVLayerVideoGravity.resizeAspectFill

            do {
                try device.lockForConfiguration()
            } catch {
                return

            }
            device.focusMode = .continuousAutoFocus
            device.unlockForConfiguration()

            print("startcamera")

            if let input = try? AVCaptureDeviceInput(device: device) {
                if captureSesssion.canAddInput(input) {
                    captureSesssion.addInput(input)
                    if captureSesssion.canAddOutput(cameraOutput) {
                        captureSesssion.addOutput(cameraOutput)
                        view.layer.addSublayer(previewLayer)
                        captureSesssion.startRunning()
                    }
                } else {
                    print("issue here : captureSesssion.canAddInput")
                    _ = UIAlertController(title: "Your camera doesn't seem to be working :(", message: "Please make sure your camera works", preferredStyle: .alert)
                }
            } else {
                fatalError("TBPVC -> startCamera() : AVCaptureDeviceInput Error")
            }
        }
        func cameraPressed(){
            let outputFormat = [kCVPixelBufferPixelFormatTypeKey as String: kCMPixelFormat_32BGRA]
            let settings = AVCapturePhotoSettings(format: outputFormat)
            cameraOutput.capturePhoto(with: settings, delegate: self)

        }

        func photoOutput(_ output: AVCapturePhotoOutput, didFinishProcessingPhoto photo: AVCapturePhoto, error: Error?) {
            print("got image")
//            guard let cgImageFromPhoto = photo.cgImageRepresentation()?.takeRetainedValue() else { fatalError("cgImageRepresentation()?.takeRetainedValue error") }

            guard let imageData = photo.fileDataRepresentation() else {
                fatalError("Error while generating image from photo capture data.")

            }

            guard let uiImage = UIImage(data: imageData) else {
                fatalError("Unable to generate UIImage from image data.")
            }


            let tempImage = resizeImage(image: uiImage, targetSize: CGSize(width: 512, height: 512))
            // generate a corresponding CGImage
            guard let tempCgImage = tempImage.cgImage else {
                fatalError("Error generating CGImage")
            }

            gotImage(cgImage: tempCgImage)
        }

    @objc func image(_ image: UIImage, didFinishSavingWithError error: Error?, contextInfo: UnsafeRawPointer) {
        if let error = error {
            let ac = UIAlertController(title: "Save error", message: error.localizedDescription, preferredStyle: .alert)
            ac.addAction(UIAlertAction(title: "OK", style: .default))
            present(ac, animated: true)
        } else {
            let ac = UIAlertController(title: "Saved!", message: "Your altered image has been saved to your photos.", preferredStyle: .alert)
            ac.addAction(UIAlertAction(title: "OK", style: .default))
            present(ac, animated: true)
        }
    }
}

0 Answers:

No answers yet