Caps-iOS/CapCollector/Classifier.swift
Christoph Hagen dceb3ca07d - Remove Xcode and Regnet classifier
- Add MobileNet classifier
- Add average color for each cap
- Add option to show average colors in mosaic
2019-07-17 11:10:07 +02:00

181 lines
5.5 KiB
Swift

//
// Classifier.swift
// CapFinder
//
// Created by User on 12.02.18.
// Copyright © 2018 User. All rights reserved.
//
import Foundation
import Vision
import CoreML
import UIKit
/// Notifies a delegate about the outcome of an image classification run.
/// Both callbacks are invoked on the main queue.
protocol ClassifierDelegate {
/// Classification finished and the matches have been stored.
/// - parameter image: The classified image, or `nil` if the caller
///   requested not to receive it back (see `recognise(image:reportingImage:)`).
func classifier(finished image: UIImage?)
/// Classification failed.
/// - parameter error: A human-readable description of the failure.
func classifier(error: String)
}
/// Recognise categories in images
/// Recognises cap categories in images by running a CoreML image classifier
/// through the Vision framework and storing the resulting match confidences.
class Classifier: Logger {

    static let logToken = "[Classifier]"

    /// Shared singleton instance
    static var shared = Classifier()

    /// Handles errors and recognised features
    var delegate: ClassifierDelegate?

    // MARK: Stored predictions

    // True when the delegate should receive the classified image back
    private var notify = false

    // The image currently being classified; cleared after the delegate is notified
    private var image: UIImage?

    /**
     Classify an image.
     - parameter image: The image to classify
     - parameter reportingImage: Set to true, if the delegate should receive the image
     */
    func recognise(image: UIImage, reportingImage: Bool = true) {
        self.image = image
        notify = reportingImage
        // Select the model according to the stored user preference
        let model = Persistence.useMobileNet ? MobileNet().model : Squeezenet().model
        performClassifications(on: image, model: model)
    }

    /// Builds and performs the Vision request on a background queue.
    /// - parameter image: The image to classify (passed explicitly to avoid
    ///   force-unwrapping the stored property).
    /// - parameter model: The CoreML model to run.
    private func performClassifications(on image: UIImage, model: MLModel) {
        let orientation = CGImagePropertyOrientation(image.imageOrientation)
        guard let ciImage = CIImage(image: image) else {
            report(error: "Unable to create CIImage")
            return
        }
        DispatchQueue.global(qos: .userInitiated).async { [weak self] in
            // A malformed model is reported through the normal error path
            // instead of crashing with `try!`
            let visionModel: VNCoreMLModel
            do {
                visionModel = try VNCoreMLModel(for: model)
            } catch {
                self?.report(error: "Invalid CoreML model: \(error.localizedDescription)")
                return
            }
            let request = VNCoreMLRequest(model: visionModel) { [weak self] request, error in
                guard let self = self else {
                    Classifier.event("Self not captured, instance deallocated?")
                    return
                }
                self.process(request: request, error: error)
            }
            request.imageCropAndScaleOption = .centerCrop
            let handler = VNImageRequestHandler(ciImage: ciImage, orientation: orientation)
            do {
                try handler.perform([request])
            } catch {
                // report(error:) already hops to the main queue
                self?.report(error: "Classification failed: \(error.localizedDescription)")
            }
        }
    }

    /// Completion handler for the Vision request: extracts the prediction
    /// in whichever form the model produced it and forwards it.
    private func process(request: VNRequest, error: Error?) {
        if let e = error {
            report(error: "Unable to classify image: \(e.localizedDescription)")
            return
        }
        // Classifier models deliver labelled observations
        if let result = request.results as? [VNClassificationObservation] {
            process(classification: dict(from: result))
            return
        }
        // Plain CoreML models deliver a raw multi-array of confidences
        if let result = (request.results as? [VNCoreMLFeatureValueObservation])?.first?.featureValue.multiArrayValue {
            process(classification: dict(from: result))
            return
        }
        report(error: "Invalid classifier result: \(String(describing: request.results))")
    }

    /// Stores the match confidence for each cap and notifies the delegate.
    /// - parameter classification: Confidence for each cap id.
    private func process(classification: [Int : Float]) {
        Cap.unsortedCaps.forEach { cap in
            // Caps the model does not know about get a zero confidence
            cap.match = classification[cap.id] ?? 0
        }
        Cap.hasMatches = true
        // Update the count of recognized caps
        Persistence.recognizedCapCount = classification.count
        report()
    }

    /// Create a dictionary from a vision prediction.
    /// Observations whose identifier is not a numeric cap id are skipped
    /// (previously they all collapsed onto key 0, which could crash the
    /// unique-keys dictionary initializer on duplicates).
    private func dict(from results: [VNClassificationObservation]) -> [Int : Float] {
        let pairs = results.compactMap { observation -> (Int, Float)? in
            guard let id = Int(observation.identifier) else { return nil }
            return (id, observation.confidence)
        }
        return [Int : Float](pairs, uniquingKeysWith: { first, _ in first })
    }

    /// Create a dictionary from a raw CoreML output array.
    /// Output indices are shifted by one, since cap ids start at 1.
    private func dict(from results: MLMultiArray) -> [Int : Float] {
        // Use the NSNumber subscript so every element data type
        // (double, float32, int32) is converted correctly, instead of
        // blindly reinterpreting the buffer as Double
        var output = [Int : Float](minimumCapacity: results.count)
        for index in 0..<results.count {
            output[index + 1] = results[index].floatValue
        }
        return output
    }

    // MARK: Callbacks

    /// Forwards an error to the delegate on the main queue,
    /// or logs it when no delegate is set.
    private func report(error message: String) {
        guard delegate != nil else {
            error("No delegate: " + message)
            return
        }
        DispatchQueue.main.async {
            self.image = nil
            self.delegate?.classifier(error: message)
        }
    }

    /// Notifies the delegate on the main queue that classification finished,
    /// passing the image back only when requested via `recognise`.
    private func report() {
        guard delegate != nil else {
            error("No delegate")
            return
        }
        DispatchQueue.main.async {
            let img = self.notify ? self.image : nil
            self.image = nil
            self.delegate?.classifier(finished: img)
        }
    }
}
extension CGImagePropertyOrientation {
    /**
     Converts a `UIImage.Orientation` to a corresponding
     `CGImagePropertyOrientation`. The cases for each
     orientation are represented by different raw values.
     - Tag: ConvertOrientation
     */
    init(_ orientation: UIImage.Orientation) {
        switch orientation {
        case .up: self = .up
        case .upMirrored: self = .upMirrored
        case .down: self = .down
        case .downMirrored: self = .downMirrored
        case .left: self = .left
        case .leftMirrored: self = .leftMirrored
        case .right: self = .right
        case .rightMirrored: self = .rightMirrored
        @unknown default:
            // UIImage.Orientation is a non-frozen Objective-C enum and may
            // gain cases in future SDKs; fall back to the identity orientation
            self = .up
        }
    }
}