// Commit dceb3ca07d:
// - Add MobileNet classifier
// - Add average color for each cap
// - Add option to show average colors in mosaic
// (181 lines, 5.5 KiB, Swift)
//
//  VisionHandler.swift
//  CapFinder
//
//  Created by User on 12.02.18.
//  Copyright © 2018 User. All rights reserved.
//
import Foundation
|
|
import Vision
|
|
import CoreML
|
|
import UIKit
|
|
|
|
/// Receives the outcome of a classification run started by `Classifier`.
protocol ClassifierDelegate {

    /// Called when classification could not be completed.
    /// - parameter error: A human-readable description of what went wrong
    func classifier(error: String)

    /// Called when classification has finished and the cap matches were updated.
    /// - parameter image: The classified image, or `nil` if the caller did not ask for it
    func classifier(finished image: UIImage?)
}
|
|
|
|
/// Recognises cap categories in images with a Core ML image classifier.
///
/// The model is chosen on every call from `Persistence.useMobileNet`
/// (MobileNet or Squeezenet). Scores are written to each cap in
/// `Cap.unsortedCaps`, and the delegate is notified on the main queue.
class Classifier: Logger {

    static let logToken = "[Classifier]"

    /// Shared instance.
    static var shared = Classifier()

    /// Handles errors and recognised features.
    /// NOTE(review): held strongly, because `ClassifierDelegate` is not class-bound
    /// and therefore cannot be declared `weak` here.
    var delegate: ClassifierDelegate?

    // MARK: Stored predictions

    /// Whether the delegate should receive `image` when classification finishes.
    private var notify = false

    /// The image currently being classified. It is cleared when a result or an
    /// error is delivered to the delegate.
    private var image: UIImage?

    /**
     Classify an image

     - parameter image: The image to classify
     - parameter reportingImage: Set to true, if the delegate should receive the image
     */
    func recognise(image: UIImage, reportingImage: Bool = true) {
        self.image = image
        notify = reportingImage
        let model = Persistence.useMobileNet ? MobileNet().model : Squeezenet().model
        performClassifications(model: model, on: image)
    }

    /// Runs a Vision request for `image` with `model` on a background queue.
    ///
    /// - parameter model: The Core ML classifier to wrap in a Vision request
    /// - parameter image: The image to classify. Its orientation is forwarded to Vision.
    private func performClassifications(model: MLModel, on image: UIImage) {
        let orientation = CGImagePropertyOrientation(image.imageOrientation)
        guard let ciImage = CIImage(image: image) else {
            report(error: "Unable to create CIImage")
            return
        }

        DispatchQueue.global(qos: .userInitiated).async {
            let handler = VNImageRequestHandler(ciImage: ciImage, orientation: orientation)
            do {
                // Wrapping the model can throw (e.g. for a model Vision does not
                // support). Report the failure instead of crashing with `try!`.
                let visionModel = try VNCoreMLModel(for: model)
                let request = VNCoreMLRequest(model: visionModel) { [weak self] request, error in
                    guard let self = self else {
                        Classifier.event("Self not captured, instance deallocated?")
                        return
                    }
                    self.process(request: request, error: error)
                }
                request.imageCropAndScaleOption = .centerCrop
                try handler.perform([request])
            } catch {
                // `report(error:)` already hops to the main queue itself.
                self.report(error: "Classification failed: \(error.localizedDescription)")
            }
        }
    }

    /// Turns a finished Vision request into per-cap scores.
    ///
    /// Accepts either classification observations (label + confidence) or a
    /// single feature-value observation that carries a multi-array of scores.
    /// Any other result shape is reported as an error.
    private func process(request: VNRequest, error: Error?) {
        if let failure = error {
            report(error: "Unable to classify image: \(failure.localizedDescription)")
            return
        }
        if let observations = request.results as? [VNClassificationObservation] {
            process(classification: dict(from: observations))
            return
        }
        if let scores = (request.results as? [VNCoreMLFeatureValueObservation])?.first?.featureValue.multiArrayValue {
            process(classification: dict(from: scores))
            return
        }
        report(error: "Invalid classifier result: \(String(describing: request.results))")
    }

    /// Stores each cap's score in `cap.match` (0 if the cap is not in
    /// `classification`), records how many caps were scored, and notifies
    /// the delegate.
    ///
    /// NOTE(review): this runs on the Vision completion thread (the background
    /// queue used in `performClassifications`), not on the main queue.
    private func process(classification: [Int : Float]) {
        Cap.unsortedCaps.forEach { cap in
            cap.match = classification[cap.id] ?? 0
        }
        Cap.hasMatches = true

        // Update the count of recognized caps
        Persistence.recognizedCapCount = classification.count

        report()
    }

    /// Create a dictionary from a vision prediction.
    ///
    /// Labels are expected to be integer cap ids. Labels that are not integers
    /// are skipped. If a label occurs more than once, the highest confidence is
    /// kept; `init(uniqueKeysWithValues:)` would trap on a duplicate instead.
    private func dict(from results: [VNClassificationObservation]) -> [Int : Float] {
        let pairs = results.compactMap { item -> (Int, Float)? in
            guard let id = Int(item.identifier) else { return nil }
            return (id, item.confidence)
        }
        return [Int : Float](pairs, uniquingKeysWith: { max($0, $1) })
    }

    /// Create a dictionary from a model's raw score array.
    ///
    /// Element `i` (linear index) becomes the score for key `i + 1`.
    /// Elements are read through the `NSNumber` subscript, so `.double`,
    /// `.float32` and `.int32` arrays are all converted correctly. The previous
    /// raw `Double` binding misread (and overran) non-double buffers.
    private func dict(from results: MLMultiArray) -> [Int : Float] {
        let length = results.count
        var scores = [Int : Float](minimumCapacity: length)
        for index in 0..<length {
            scores[index + 1] = results[index].floatValue
        }
        return scores
    }

    // MARK: Callbacks

    /// Delivers `message` to the delegate on the main queue and clears the
    /// current image. Without a delegate, the message is only logged.
    private func report(error message: String) {
        guard delegate != nil else {
            error("No delegate: " + message)
            return
        }
        DispatchQueue.main.async {
            self.image = nil
            self.delegate?.classifier(error: message)
        }
    }

    /// Tells the delegate on the main queue that classification finished.
    /// The image is passed along only if it was requested, and is then cleared.
    private func report() {
        guard delegate != nil else {
            error("No delegate")
            return
        }
        DispatchQueue.main.async {
            let img = self.notify ? self.image : nil
            self.image = nil
            self.delegate?.classifier(finished: img)
        }
    }
}
|
|
|
|
extension CGImagePropertyOrientation {
    /**
     Converts a `UIImage.Orientation` to the corresponding
     `CGImagePropertyOrientation`.

     Both enums describe the same eight orientations but use different raw
     values, so they are mapped case by case rather than through `rawValue`.
     `UIImage.Orientation` is not a frozen enum. Any orientation that a future
     SDK adds falls back to `.up` instead of being left unhandled.

     - Tag: ConvertOrientation
     */
    init(_ orientation: UIImage.Orientation) {
        switch orientation {
        case .up: self = .up
        case .upMirrored: self = .upMirrored
        case .down: self = .down
        case .downMirrored: self = .downMirrored
        case .left: self = .left
        case .leftMirrored: self = .leftMirrored
        case .right: self = .right
        case .rightMirrored: self = .rightMirrored
        @unknown default: self = .up
        }
    }
}
|