//
//  VisionHandler.swift
//  CapFinder
//
//  Created by User on 12.02.18.
//  Copyright © 2018 User. All rights reserved.
//

import Foundation
import Vision
import CoreML
import UIKit

/// Notify the delegate about classification results and errors
protocol ClassifierDelegate {

    /// Features found
    func classifier(finished image: UIImage?)

    /// Error handler
    func classifier(error: String)
}
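
// A minimal sketch of a conforming type (`ResultsViewController` is a
// hypothetical name, not part of this project):
//
//     extension ResultsViewController: ClassifierDelegate {
//         func classifier(finished image: UIImage?) {
//             // Update the UI with the matched caps
//         }
//         func classifier(error: String) {
//             // Present the error message to the user
//         }
//     }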

/// Recognise categories in images
class Classifier: Logger {

    static let logToken = "[Classifier]"

    static var shared = Classifier()

    /// Handles errors and recognised features
    var delegate: ClassifierDelegate?

    // MARK: Stored predictions

    /// One dictionary of cap id to confidence per finished request
    private var predictions = [[Int : Float]]()

    /// Whether the delegate should receive the image with the result
    private var notify = false

    /// The image currently being classified
    private var image: UIImage?

    /// Create a Vision request for the given Core ML model
    private func request(for model: MLModel, name: String) -> VNCoreMLRequest {
        // A failure here means a bundled model is invalid, so crashing is acceptable
        let visionModel = try! VNCoreMLModel(for: model)
        let request = VNCoreMLRequest(model: visionModel) { [weak self] request, error in
            guard let self = self else { return }
            self.process(request: request, error: error)
            self.event("Finished \(name) prediction (\(self.predictions.count)/\(self.requestCount))")
        }
        request.imageCropAndScaleOption = .centerCrop
        return request
    }

    /// The number of requests issued for the current classification
    private var requestCount = 0

    /// One request for each model enabled in the settings
    private var requests: [VNCoreMLRequest] {
        var reqs = [VNCoreMLRequest]()
        if Persistence.squeezenet {
            reqs.append(request(for: Squeezenet().model, name: "Squeezenet"))
        }
        if Persistence.resnet {
            reqs.append(request(for: Resnet().model, name: "Resnet"))
        }
        if Persistence.xcode {
            reqs.append(request(for: ImageClassifier().model, name: "Xcode"))
        }
        requestCount = reqs.count
        return reqs
    }

    /**
     Classify an image

     - parameter image: The image to classify
     - parameter reportingImage: Set to true if the delegate should receive the image
     */
    func recognise(image: UIImage, reportingImage: Bool = true) {
        predictions.removeAll()
        self.image = image
        notify = reportingImage
        performClassifications()
    }
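
    // Usage sketch (`pickedImage` stands in for any UIImage, e.g. from an
    // image picker):
    //
    //     Classifier.shared.delegate = self
    //     Classifier.shared.recognise(image: pickedImage)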

    /// Run all enabled classifiers against the stored image
    private func performClassifications() {
        guard let image = image, let ciImage = CIImage(image: image) else {
            report(error: "Unable to create CIImage")
            return
        }
        let orientation = CGImagePropertyOrientation(image.imageOrientation)

        DispatchQueue.global(qos: .userInitiated).async {
            let handler = VNImageRequestHandler(ciImage: ciImage, orientation: orientation)
            let requests = self.requests
            guard requests.count > 0 else {
                self.report(error: "No classifiers selected")
                return
            }
            do {
                try handler.perform(requests)
            } catch {
                DispatchQueue.main.async {
                    self.report(error: "Classification failed: \(error.localizedDescription)")
                }
            }
        }
    }

    /// Collect the classifications of a completed Vision request
    private func process(request: VNRequest, error: Error?) {
        guard let result = request.results as? [VNClassificationObservation],
            result.isEmpty == false else {
                report(error: "Unable to classify image: \(error?.localizedDescription ?? "No error thrown")")
                return
        }
        let current = dict(from: result)
        predictions.append(current)

        // Once every request has reported back, combine the results
        if predictions.count == requestCount {
            updateRecognizedCapsCount()
            combine()
        }
    }

    /// Create a dictionary of cap ids and confidences from a vision prediction
    private func dict(from results: [VNClassificationObservation]) -> [Int : Float] {
        let array = results.map { item -> (Int, Float) in
            (Int(item.identifier) ?? 0, item.confidence)
        }
        // Identifiers that fail to parse as Int all collapse onto key 0, so
        // duplicate keys are possible; keep the highest confidence instead of trapping
        return [Int : Float](array, uniquingKeysWith: max)
    }
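
    // Example: observations ("7", 0.83) and ("12", 0.02) become [7: 0.83, 12: 0.02].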

    /// Combine all predictions, keeping the highest confidence reported for each cap
    private func combine() {
        Cap.unsortedCaps.forEach { cap in
            var result: Float = 0
            for index in 0..<predictions.count {
                result = max(predictions[index][cap.id] ?? 0, result)
            }
            cap.match = result
        }
        Cap.hasMatches = true
        report()
    }
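
    // Example: with predictions [[7: 0.2, 12: 0.6], [7: 0.9]] cap 7 gets a
    // match of 0.9 and cap 12 a match of 0.6: the most confident classifier
    // wins for each cap.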

    /// Store the largest number of caps recognised by any single classifier
    private func updateRecognizedCapsCount() {
        let recognizedCaps = predictions.map { prediction in
            return prediction.count
        }
        Persistence.recognizedCapCount = recognizedCaps.max() ?? 0
    }

    // MARK: Callbacks

    /// Discard the stored predictions and image
    private func cleanup() {
        predictions.removeAll()
        image = nil
    }

    /// Report an error to the delegate on the main queue
    private func report(error message: String) {
        guard delegate != nil else {
            error("No delegate: " + message)
            return
        }
        DispatchQueue.main.async {
            self.cleanup()
            self.delegate?.classifier(error: message)
        }
    }

    /// Report the finished classification to the delegate on the main queue
    private func report() {
        guard delegate != nil else {
            error("No delegate")
            return
        }
        DispatchQueue.main.async {
            let img = self.notify ? self.image : nil
            self.cleanup()
            self.delegate?.classifier(finished: img)
        }
    }
}

extension CGImagePropertyOrientation {
    /**
     Converts a `UIImage.Orientation` to a corresponding
     `CGImagePropertyOrientation`. The cases for each
     orientation are represented by different raw values.

     - Tag: ConvertOrientation
     */
    init(_ orientation: UIImage.Orientation) {
        switch orientation {
        case .up: self = .up
        case .upMirrored: self = .upMirrored
        case .down: self = .down
        case .downMirrored: self = .downMirrored
        case .left: self = .left
        case .leftMirrored: self = .leftMirrored
        case .right: self = .right
        case .rightMirrored: self = .rightMirrored
        @unknown default: self = .up // Fall back for orientations added in future SDKs
        }
    }
}