2022-06-24 12:13:37 +02:00
|
|
|
import Foundation
|
2021-12-20 17:52:32 +01:00
|
|
|
import Cocoa
|
|
|
|
import CreateML
|
|
|
|
|
2023-01-15 11:21:47 +01:00
|
|
|
/// Settings for a classifier run, decoded from a JSON configuration file.
struct Configuration: Codable {

    /// Local directory in which the downloaded cap images are stored.
    let imageDirectory: String

    /// Local file path the trained classifier model is written to.
    let classifierModelPath: String

    /// Maximum number of training iterations passed to the image classifier.
    let trainingIterations: Int

    /// Base URL of the server, as a string.
    let serverPath: String

    /// Token sent to the server with every upload request.
    let authenticationToken: String

    /// Reads the file at `url` and decodes it as a `Configuration`.
    ///
    /// - Parameter url: Location of the JSON configuration file.
    /// - Returns: `nil` if the file cannot be read or decoded; the underlying
    ///   error is printed in that case.
    init?(at url: URL) {
        do {
            // A single `try` covers both reading the file and decoding its contents.
            self = try JSONDecoder().decode(Self.self, from: Data(contentsOf: url))
        } catch {
            print("[ERROR] Failed to load configuration at \(url.absoluteURL.path): \(error)")
            return nil
        }
    }
}
|
2022-06-24 12:13:37 +02:00
|
|
|
|
2023-01-15 11:21:47 +01:00
|
|
|
/// One entry of the cap database returned by the server (`caps.json`).
struct Cap: Codable {

    /// Identifier of the cap.
    let id: Int

    /// Number of images the server holds for this cap.
    let count: Int

    /// The JSON representation uses single-letter keys: `i` for the id, `c` for the count.
    enum CodingKeys: String, CodingKey {
        case id = "i", count = "c"
    }
}
|
|
|
|
|
2023-01-15 11:21:47 +01:00
|
|
|
/// Synchronizes cap images from the server, trains an image classifier on them with
/// CreateML, and uploads the trained model and the list of trained classes.
///
/// All settings are read from `config.json` in the current working directory.
final class ClassifierCreator {

    /// Location of the configuration file (relative to the current working directory).
    static let configurationFileUrl = URL(fileURLWithPath: "config.json")

    /// Upper bound on simultaneous image downloads, so that large image sets do not
    /// enqueue every request (and temporary file) at once.
    private static let maxConcurrentDownloads = 8

    /// Base URL of the server API.
    let server: URL

    /// The decoded configuration file.
    let configuration: Configuration

    /// Local directory containing one sub-directory of images per cap.
    let imageDirectory: URL

    /// Local file the trained model is written to and uploaded from.
    let classifierUrl: URL

    /// Formats the image snapshot date sent with the class list.
    /// `en_US_POSIX` keeps this fixed format independent of the user's locale settings.
    let df: DateFormatter = {
        let formatter = DateFormatter()
        formatter.locale = Locale(identifier: "en_US_POSIX")
        formatter.dateFormat = "yy-MM-dd-HH-mm-ss"
        return formatter
    }()

    // MARK: Step 1: Load configuration

    /// Loads the configuration file and derives the server and file locations from it.
    ///
    /// - Returns: `nil` if the configuration cannot be loaded or the server path is not a
    ///   valid URL; the reason is printed in both cases.
    init?() {
        guard let configuration = Configuration(at: ClassifierCreator.configurationFileUrl) else {
            return nil
        }
        self.configuration = configuration
        guard let serverUrl = URL(string: configuration.serverPath) else {
            print("[ERROR] Configuration: Invalid server path \(configuration.serverPath)")
            return nil
        }
        self.server = serverUrl
        self.imageDirectory = URL(fileURLWithPath: configuration.imageDirectory)
        self.classifierUrl = URL(fileURLWithPath: configuration.classifierModelPath)
    }

    // MARK: Main function

    /// Runs all steps in order: sync images, determine the next version, train,
    /// upload the model, then upload the class list. Stops at the first failing step.
    func run() async {
        // Captured before any images are loaded; this date is uploaded with the class list.
        let imagesSnapshotDate = Date()
        guard let (classes, changedImageCount) = await loadImages() else {
            return
        }
        guard !classes.isEmpty else {
            print("[WARN] No classes available, so no classifier trained")
            return
        }
        guard changedImageCount > 0 else {
            print("[INFO] No changed images, so no new classifier trained")
            return
        }
        guard let classifierVersion = await getClassifierVersion() else {
            return
        }
        let newVersion = classifierVersion + 1

        print("[INFO] Image directory: \(imageDirectory.absoluteURL.path)")
        print("[INFO] Model path: \(configuration.classifierModelPath)")
        print("[INFO] Version: \(newVersion)")
        print("[INFO] Classes: \(classes.count)")
        print("[INFO] Iterations: \(configuration.trainingIterations)")

        guard trainModel() else {
            return
        }
        guard await upload(version: newVersion) else {
            return
        }
        guard await upload(classes: classes, lastUpdate: imagesSnapshotDate) else {
            return
        }
        print("[INFO] Done")
    }

    // MARK: Step 2: Load changed images

    /// Downloads all images that are missing locally, then all images listed as changed.
    ///
    /// - Returns: The sorted cap ids known to the server and the number of images
    ///   downloaded, or `nil` if the image folder can't be created or a download fails.
    func loadImages() async -> (classes: [Int], changedImages: Int)? {
        guard createImageFolderIfMissing() else {
            return nil
        }
        let imageCounts = await getImageCounts()
        let missingImageList: [(cap: Int, image: Int)] = imageCounts
            .sorted { $0.key < $1.key }
            .reduce(into: []) { list, pair in
                // Clamp so that a negative count from the server can't trap when forming the range.
                let missingImagesForCap: [(cap: Int, image: Int)] = (0..<max(pair.value, 0)).compactMap { image in
                    let url = imageUrl(base: imageDirectory, cap: pair.key, image: image)
                    guard !FileManager.default.fileExists(atPath: url.path) else {
                        return nil
                    }
                    return (cap: pair.key, image: image)
                }
                list.append(contentsOf: missingImagesForCap)
            }

        if missingImageList.isEmpty {
            print("[INFO] No missing images to load")
        } else {
            print("[INFO] Loading \(missingImageList.count) missing images...")
        }
        guard await loadImages(missingImageList) else {
            return nil
        }

        let reportedChanges = await getChangedImageList()
        // Images just downloaded as missing are already current, and duplicate entries
        // would make two concurrent downloads race on the same local file.
        let changedImageList = uniqueImages(reportedChanges, excluding: missingImageList)

        if changedImageList.isEmpty {
            print("[INFO] No changed images to load")
        } else {
            print("[INFO] Loading \(changedImageList.count) changed images...")
        }
        guard await loadImages(changedImageList) else {
            return nil
        }

        let classes = imageCounts.keys.sorted()
        return (classes, missingImageList.count + changedImageList.count)
    }

    /// Returns `list` in its original order, without duplicates and without any entry in `excluded`.
    private func uniqueImages(
        _ list: [(cap: Int, image: Int)],
        excluding excluded: [(cap: Int, image: Int)]
    ) -> [(cap: Int, image: Int)] {
        var seen = Set(excluded.map { [$0.cap, $0.image] })
        return list.filter { seen.insert([$0.cap, $0.image]).inserted }
    }

    /// Creates the local image directory (including intermediate directories) if needed.
    ///
    /// - Returns: `false` if the directory could not be created; the error is printed.
    private func createImageFolderIfMissing() -> Bool {
        guard !FileManager.default.fileExists(atPath: imageDirectory.path) else {
            return true
        }
        do {
            try FileManager.default.createDirectory(at: imageDirectory, withIntermediateDirectories: true)
            return true
        } catch {
            print("[ERROR] Failed to create image directory: \(error)")
            return false
        }
    }

    /// Fetches `caps.json` from the server and maps each cap id to its image count.
    ///
    /// - Returns: An empty dictionary if the request or decoding fails; the error is printed.
    private func getImageCounts() async -> [Int : Int] {
        guard let data: Data = await get(server.appendingPathComponent("caps.json")) else {
            return [:]
        }
        do {
            return try JSONDecoder().decode([Cap].self, from: data)
                .reduce(into: [:]) { $0[$1.id] = $1.count }
        } catch {
            print("[ERROR] Failed to decode cap database: \(error)")
            return [:]
        }
    }

    /// Builds the image location `<base>/<cap>/<cap>-<image>.jpg`, zero-padding the cap
    /// id to 4 digits and the image index to 2 digits.
    private func imageUrl(base: URL, cap: Int, image: Int) -> URL {
        // `%ld` matches Swift's 64-bit `Int`; `%d` expects a 32-bit value.
        base.appendingPathComponent(String(format: "%04ld/%04ld-%02ld.jpg", cap, cap, image))
    }

    /// Downloads a single image from the server into the local image directory,
    /// replacing any existing file at that location.
    ///
    /// - Returns: `true` on success; `false` after printing the reason otherwise.
    private func loadImage(cap: Int, image: Int) async -> Bool {
        let url = imageUrl(base: server.appendingPathComponent("images"), cap: cap, image: image)
        let tempFile: URL
        let response: URLResponse
        do {
            (tempFile, response) = try await URLSession.shared.download(from: url)
        } catch {
            print("[ERROR] Failed to load image \(image) of cap \(cap): \(error)")
            return false
        }
        // Remove the temporary download on every path where it was not moved into place.
        // After a successful move, this is a no-op.
        defer { try? FileManager.default.removeItem(at: tempFile) }

        guard let httpResponse = response as? HTTPURLResponse else {
            print("[ERROR] Failed to load image \(image) of cap \(cap): Not an HTTP response")
            return false
        }
        let responseCode = httpResponse.statusCode
        guard responseCode == 200 else {
            print("[ERROR] Failed to load image \(image) of cap \(cap): Response \(responseCode)")
            return false
        }

        let localUrl = imageUrl(base: imageDirectory, cap: cap, image: image)
        do {
            // Each cap has its own sub-directory, which does not exist yet for a new cap.
            try FileManager.default.createDirectory(
                at: localUrl.deletingLastPathComponent(),
                withIntermediateDirectories: true)
            if FileManager.default.fileExists(atPath: localUrl.path) {
                try FileManager.default.removeItem(at: localUrl)
            }
            try FileManager.default.moveItem(at: tempFile, to: localUrl)
            return true
        } catch {
            print("[ERROR] Failed to save image \(image) of cap \(cap): \(error)")
            return false
        }
    }

    /// Downloads every image in `list`, running at most `maxConcurrentDownloads` at a time.
    ///
    /// - Returns: `true` if every image was downloaded (or `list` is empty).
    private func loadImages(_ list: [(cap: Int, image: Int)]) async -> Bool {
        guard !list.isEmpty else {
            return true
        }
        let loadedImages = await withTaskGroup(of: Bool.self, returning: Int.self) { group in
            var remaining = list.makeIterator()
            var successCount = 0
            // Fill the initial window of concurrent downloads.
            for _ in 0..<ClassifierCreator.maxConcurrentDownloads {
                guard let next = remaining.next() else { break }
                group.addTask { await self.loadImage(cap: next.cap, image: next.image) }
            }
            // Each finished download frees a slot for the next image in the list.
            while let loaded = await group.next() {
                if loaded {
                    successCount += 1
                }
                if let next = remaining.next() {
                    group.addTask { await self.loadImage(cap: next.cap, image: next.image) }
                }
            }
            return successCount
        }
        if loadedImages != list.count {
            print("[ERROR] Only \(loadedImages) of \(list.count) images loaded")
            return false
        }
        return true
    }

    /// Fetches `changes.txt` from the server and parses it into (cap, image) pairs.
    ///
    /// Each non-empty line is expected to contain three `:`-separated fields; the cap
    /// id and image index are the second and third. The first field (a date, according
    /// to previously commented-out parsing code) is not evaluated. Lines without three
    /// fields are skipped silently; lines with non-numeric ids are skipped with a warning.
    ///
    /// - Returns: The parsed pairs, or an empty list if the request fails.
    func getChangedImageList() async -> [(cap: Int, image: Int)] {
        guard let string: String = await get(server.appendingPathComponent("changes.txt")) else {
            print("[ERROR] Failed to get list of changed images")
            return []
        }
        // Splitting on the newline character set also handles CRLF line endings, which
        // would otherwise leave a trailing "\r" that makes the image id fail to parse.
        return string
            .components(separatedBy: .newlines)
            .map { $0.trimmingCharacters(in: .whitespaces) }
            .filter { !$0.isEmpty }
            .compactMap { line -> (cap: Int, image: Int)? in
                let parts = line.components(separatedBy: ":")
                guard parts.count == 3 else {
                    return nil
                }
                guard let cap = Int(parts[1]) else {
                    print("[WARN] Invalid cap id \(parts[1]) in change list")
                    return nil
                }
                guard let image = Int(parts[2]) else {
                    print("[WARN] Invalid image id \(parts[2]) in change list")
                    return nil
                }
                return (cap, image)
            }
    }

    // MARK: Step 3: Compute version

    /// Fetches the server's current classifier version.
    ///
    /// - Returns: The version, or `nil` if the request fails or the response is not an integer.
    func getClassifierVersion() async -> Int? {
        guard let string: String = await get(server.appendingPathComponent("version")) else {
            print("[ERROR] Failed to get classifier version")
            return nil
        }
        // Tolerate surrounding whitespace such as a trailing newline in the response body.
        guard let version = Int(string.trimmingCharacters(in: .whitespacesAndNewlines)) else {
            print("[ERROR] Invalid classifier version \(string)")
            return nil
        }
        return version
    }

    // MARK: Step 4: Train classifier

    /// Trains an image classifier on the labeled sub-directories of `imageDirectory`
    /// and writes it to `classifierUrl`.
    ///
    /// - Returns: `true` if training and saving succeeded; errors are printed.
    func trainModel() -> Bool {
        var params = MLImageClassifier.ModelParameters(augmentation: [])
        params.maxIterations = configuration.trainingIterations
        let model: MLImageClassifier
        do {
            model = try MLImageClassifier(
                trainingData: .labeledDirectories(at: imageDirectory),
                parameters: params)
        } catch {
            print("[ERROR] Failed to create classifier: \(error)")
            return false
        }

        print("[INFO] Saving classifier...")
        do {
            try model.write(to: classifierUrl)
            return true
        } catch {
            print("[ERROR] Failed to save model to file: \(error)")
            return false
        }
    }

    // MARK: Step 5: Upload classifier

    /// Uploads the saved model file to `classifier/<version>` on the server.
    func upload(version: Int) async -> Bool {
        print("[INFO] Uploading classifier...")
        let modelData: Data
        do {
            modelData = try Data(contentsOf: classifierUrl)
        } catch {
            print("[ERROR] Failed to read classifier data: \(error)")
            return false
        }
        return await post(
            url: server.appendingPathComponent("classifier/\(version)"),
            body: modelData)
    }

    // MARK: Step 6: Update classes

    /// Uploads the trained class ids, one per line, to `classes/<formatted lastUpdate>`.
    func upload(classes: [Int], lastUpdate: Date) async -> Bool {
        print("[INFO] Uploading trained classes...")
        let dateString = df.string(from: lastUpdate)
        let body = Data(classes.map(String.init).joined(separator: "\n").utf8)
        // The snapshot date is transmitted as part of the URL path.
        return await post(
            url: server.appendingPathComponent("classes/\(dateString)"),
            body: body)
    }

    // MARK: Requests

    /// Sends `body` as a POST request, authenticated with the configured token in the `key` header.
    ///
    /// - Returns: `true` if the server answered with status 200.
    private func post(url: URL, body: Data) async -> Bool {
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.httpBody = body
        request.addValue(configuration.authenticationToken, forHTTPHeaderField: "key")
        return await perform(request) != nil
    }

    /// Performs `request` and returns the response body if the status code is 200.
    ///
    /// - Returns: `nil` on transport errors, non-HTTP responses, or other status codes;
    ///   the reason is printed.
    private func perform(_ request: URLRequest) async -> Data? {
        let data: Data
        let response: URLResponse
        do {
            (data, response) = try await URLSession.shared.data(for: request)
        } catch {
            print("[ERROR] Request failed: \(error)")
            return nil
        }
        guard let httpResponse = response as? HTTPURLResponse else {
            print("[ERROR] Invalid response: not an HTTP response")
            return nil
        }
        let code = httpResponse.statusCode
        guard code == 200 else {
            print("[ERROR] Invalid response \(code)")
            return nil
        }
        return data
    }

    /// GET request returning the raw response body.
    private func get(_ url: URL) async -> Data? {
        await perform(URLRequest(url: url))
    }

    /// GET request returning the response body decoded as UTF-8.
    private func get(_ url: URL) async -> String? {
        guard let data: Data = await get(url) else {
            return nil
        }
        guard let string = String(data: data, encoding: .utf8) else {
            print("[ERROR] Invalid string response \(data)")
            return nil
        }
        return string
    }
}
|
2021-12-20 17:52:32 +01:00
|
|
|
|
2023-01-15 11:21:47 +01:00
|
|
|
// `ClassifierCreator.init` is failable: it returns nil when the configuration can't be
// loaded or the server path is invalid, and prints the reason. The result must be
// unwrapped before calling `run()`.
await ClassifierCreator()?.run()
|