//
//  Classifier.swift
//  CapFinder
//
//  Created by User on 12.02.18.
//  Copyright © 2018 User. All rights reserved.
//
import Foundation
import Vision
import CoreML
import UIKit

/// Receives classification results and errors from a `Classifier`.
protocol ClassifierDelegate: AnyObject {

    /// Called when classification has finished.
    /// - parameter image: The classified image, if it was requested.
    func classifier(finished image: UIImage?)

    /// Called when classification failed.
    /// - parameter error: A description of the error.
    func classifier(error: String)
}
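
// A minimal conforming delegate, as a sketch (`ResultViewController` is purely
// illustrative and not part of this project):
//
//     final class ResultViewController: UIViewController, ClassifierDelegate {
//         func classifier(finished image: UIImage?) {
//             // Reload the UI from the updated `Cap.match` values
//         }
//         func classifier(error: String) {
//             // Surface the error message to the user
//         }
//     }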

/// Recognises cap categories in images using one or more Core ML models.
class Classifier: Logger {

    static let logToken = "[Classifier]"

    /// Shared instance
    static let shared = Classifier()

    /// Handles errors and recognised features
    weak var delegate: ClassifierDelegate?

    // MARK: Stored predictions

    /// One dictionary of `cap id -> confidence` per completed model request
    private var predictions = [[Int: Float]]()

    /// Whether the delegate should receive the image with the result
    private var notify = false

    /// The image currently being classified
    private var image: UIImage?

    /// Build a Vision request for a Core ML model.
    private func request(for model: MLModel, name: String) -> VNCoreMLRequest {
        // The bundled models are known to be valid, so a failure here is a programmer error
        let model = try! VNCoreMLModel(for: model)
        let request = VNCoreMLRequest(model: model) { [weak self] request, error in
            guard let self = self else { return }
            self.process(request: request, error: error)
            self.event("Finished \(name) prediction (\(self.predictions.count)/\(self.requestCount))")
        }
        request.imageCropAndScaleOption = .centerCrop
        return request
    }

    /// The number of requests performed for the current classification
    private var requestCount = 0

    /// The requests for all models enabled in `Persistence`.
    /// - note: Also updates `requestCount` as a side effect.
    private var requests: [VNCoreMLRequest] {
        var reqs = [VNCoreMLRequest]()
        if Persistence.squeezenet {
            reqs.append(request(for: Squeezenet().model, name: "Squeezenet"))
        }
        if Persistence.resnet {
            reqs.append(request(for: Resnet().model, name: "Resnet"))
        }
        if Persistence.xcode {
            reqs.append(request(for: ImageClassifier().model, name: "Xcode"))
        }
        requestCount = reqs.count
        return reqs
    }

    /**
     Classify an image.
     - parameter image: The image to classify
     - parameter reportingImage: Set to `true` if the delegate should receive the image
     */
    func recognise(image: UIImage, reportingImage: Bool = true) {
        predictions.removeAll()
        self.image = image
        notify = reportingImage
        performClassifications()
    }
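
    // Usage sketch (the delegate and image source are illustrative):
    //
    //     Classifier.shared.delegate = self          // some ClassifierDelegate
    //     Classifier.shared.recognise(image: photo)  // `photo` is a UIImage
    //
    // The result arrives asynchronously through `classifier(finished:)`,
    // with each cap's best confidence stored in `Cap.match`.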

    private func performClassifications() {
        guard let image = image else {
            report(error: "No image to classify")
            return
        }
        let orientation = CGImagePropertyOrientation(image.imageOrientation)
        guard let ciImage = CIImage(image: image) else {
            report(error: "Unable to create CIImage")
            return
        }
        DispatchQueue.global(qos: .userInitiated).async {
            let handler = VNImageRequestHandler(ciImage: ciImage, orientation: orientation)
            let requests = self.requests
            guard !requests.isEmpty else {
                self.report(error: "No classifiers selected")
                return
            }
            do {
                try handler.perform(requests)
            } catch {
                // report(error:) already hops to the main queue
                self.report(error: "Classification failed: \(error.localizedDescription)")
            }
        }
    }

    /// Handle the results of a completed Vision request.
    private func process(request: VNRequest, error: Error?) {
        guard let result = request.results as? [VNClassificationObservation],
            !result.isEmpty else {
                report(error: "Unable to classify image: \(error?.localizedDescription ?? "No error thrown")")
                return
        }
        predictions.append(dict(from: result))
        // Combine once all enabled models have reported
        if predictions.count == requestCount {
            updateRecognizedCapsCount()
            combine()
        }
    }

    /// Create a dictionary of `cap id -> confidence` from a Vision prediction.
    private func dict(from results: [VNClassificationObservation]) -> [Int: Float] {
        let array = results.map { item -> (Int, Float) in
            (Int(item.identifier) ?? 0, item.confidence)
        }
        // `uniquingKeysWith:` guards against duplicate keys, which can occur
        // if several identifiers are not numeric and fall back to 0
        return [Int: Float](array, uniquingKeysWith: max)
    }

    /// Combine all predictions, keeping each cap's highest confidence across models.
    private func combine() {
        Cap.unsortedCaps.forEach { cap in
            cap.match = predictions.compactMap { $0[cap.id] }.max() ?? 0
        }
        Cap.hasMatches = true
        report()
    }
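
    // Aggregation sketch (values illustrative): if Squeezenet reports 0.30 and
    // Resnet reports 0.55 for the same cap id, the stored match is
    // max(0.30, 0.55) = 0.55; a cap missing from every prediction gets 0.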

    /// Store the largest number of caps recognised by any single model.
    private func updateRecognizedCapsCount() {
        Persistence.recognizedCapCount = predictions.map { $0.count }.max() ?? 0
    }

    // MARK: Callbacks

    private func cleanup() {
        predictions.removeAll()
        image = nil
    }

    private func report(error message: String) {
        guard delegate != nil else {
            error("No delegate: " + message)
            return
        }
        DispatchQueue.main.async {
            self.cleanup()
            self.delegate?.classifier(error: message)
        }
    }

    private func report() {
        guard delegate != nil else {
            error("No delegate")
            return
        }
        DispatchQueue.main.async {
            let img = self.notify ? self.image : nil
            self.cleanup()
            self.delegate?.classifier(finished: img)
        }
    }
}

extension CGImagePropertyOrientation {

    /**
     Converts a `UIImage.Orientation` to the corresponding
     `CGImagePropertyOrientation`. The cases for each
     orientation are represented by different raw values.
     - Tag: ConvertOrientation
     */
    init(_ orientation: UIImage.Orientation) {
        switch orientation {
        case .up: self = .up
        case .upMirrored: self = .upMirrored
        case .down: self = .down
        case .downMirrored: self = .downMirrored
        case .left: self = .left
        case .leftMirrored: self = .leftMirrored
        case .right: self = .right
        case .rightMirrored: self = .rightMirrored
        // `UIImage.Orientation` is not frozen; treat unknown future cases as .up
        @unknown default: self = .up
        }
    }
}
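
// Conversion sketch (the image is a placeholder):
//
//     let uiImage = UIImage()
//     let cgOrientation = CGImagePropertyOrientation(uiImage.imageOrientation)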