import Foundation
import Cocoa
import CreateML

/// Settings for a classifier training run, decoded from a JSON file.
///
/// All properties are non-optional, so the synthesized decoder fails (and `init?(at:)`
/// returns `nil`) if any key is missing.
struct Configuration: Codable {
    /// Local directory holding the training images (one subdirectory per cap).
    let imageDirectory: String
    /// File path the trained model is written to and uploaded from.
    let classifierModelPath: String
    /// Passed to CreateML as `ModelParameters.maxIterations`.
    let trainingIterations: Int
    /// Base URL of the server that provides images/metadata and receives the classifier.
    let serverPath: String
    /// Sent in the `key` header of every POST request.
    let authenticationToken: String

    /// Reads and decodes the configuration file at `url`.
    ///
    /// Returns `nil` (after printing the error) if the file cannot be read or decoded.
    init?(at url: URL) {
        do {
            let configData = try Data(contentsOf: url)
            self = try JSONDecoder().decode(Configuration.self, from: configData)
        } catch {
            print("[ERROR] Failed to load configuration at \(url.absoluteURL.path): \(error)")
            return nil
        }
    }
}

/// One entry of the server's cap database (`caps.json`): a cap id and its image count
/// (consumed as a cap id -> count map by `getImageCounts()`).
struct Cap: Codable {
    let id: Int
    let count: Int

    // The server's JSON uses single-letter keys.
    enum CodingKeys: String, CodingKey {
        case id = "i"
        case count = "c"
    }
}

/// Script pipeline: downloads training images from the server, trains a CreateML image
/// classifier, then uploads the model and the list of trained classes.
final class ClassifierCreator {

    /// Configuration file; a relative file URL, so it resolves against the current working directory.
    static let configurationFileUrl = URL(fileURLWithPath: "config.json")

    let server: URL
    let configuration: Configuration
    let imageDirectory: URL
    let classifierUrl: URL

    /// Formats the date path component of the classes upload (format `yy-MM-dd-HH-mm-ss`, set in `init`).
    ///
    /// NOTE(review): no `locale`/`timeZone` is set, so output follows the machine's current
    /// settings; Apple recommends `en_US_POSIX` for fixed-format dates (QA1480) — consider it.
    let df = DateFormatter()

    // MARK: Step 1: Load configuration

    /// Loads `config.json` and derives the server and file URLs from it.
    ///
    /// Returns `nil` if the configuration cannot be loaded or `serverPath` is not a valid URL.
    init?() {
        guard let configuration = Configuration(at: ClassifierCreator.configurationFileUrl) else {
            return nil
        }
        self.configuration = configuration
        guard let serverUrl = URL(string: configuration.serverPath) else {
            print("[ERROR] Configuration: Invalid server path \(configuration.serverPath)")
            return nil
        }
        self.server = serverUrl
        self.imageDirectory = URL(fileURLWithPath: configuration.imageDirectory)
        self.classifierUrl = URL(fileURLWithPath: configuration.classifierModelPath)
        df.dateFormat = "yy-MM-dd-HH-mm-ss"
    }

    // MARK: Main function

    /// Runs the whole pipeline: load images, compute the next version, train, upload model, upload classes.
    ///
    /// Stops at the first failing step. No model is trained when the changed-image count is 0.
    func run() async {
        // Taken before any image is downloaded; uploaded as `lastUpdate` in step 6.
        let imagesSnapshotDate = Date()
        guard let (classes, changedImageCount) = await loadImages() else {
            return
        }
        // NOTE(review): an empty class list ends the run without any log message.
        guard !classes.isEmpty else {
            return
        }
        guard changedImageCount > 0 else {
            print("[INFO] No changed images, so no new classifier trained")
            return
        }
        guard let classifierVersion = await getClassifierVersion() else {
            return
        }
        // The new model is published one version above the server's current one.
        let newVersion = classifierVersion + 1
        print("[INFO] Image directory: \(imageDirectory.absoluteURL.path)")
        print("[INFO] Model path: \(configuration.classifierModelPath)")
        print("[INFO] Version: \(newVersion)")
        print("[INFO] Classes: \(classes.count)")
        print("[INFO] Iterations: \(configuration.trainingIterations)")
        guard trainModel() else {
            return
        }
        guard await upload(version: newVersion) else {
            return
        }
        guard await upload(classes: classes, lastUpdate: imagesSnapshotDate) else {
            return
        }
        print("[INFO] Done")
    }

    // MARK: Step 2: Load changed images

    /// Ensures the image folder exists and computes which images are missing locally.
    /// Presumably also merges in `getChangedImageList()` and downloads them via `loadImages(_:)`.
    /// That part is not visible; see the note below.
    ///
    /// - Returns: per the signature and `run()`, the class ids to train and the number of changed
    ///   images, or `nil` on failure. How the tuple is built is not visible in this copy.
    func loadImages() async -> (classes: [Int], changedImages: Int)? {
        guard createImageFolderIfMissing() else {
            return nil
        }
        // Cap id -> image count from `caps.json` (empty if the request or decoding failed).
        let imageCounts = await getImageCounts()
        let missingImageList: [(cap: Int, image: Int)] = imageCounts
            .sorted { $0.key < $1.key }
            .reduce(into: []) { list, pair in
                // NOTE(review): THIS COPY OF THE SOURCE IS DAMAGED HERE. Everything between `(0..` and
                // `Bool {` is missing. It looks like a `<…>` span was stripped as if it were markup,
                // from the `<` of a `..<` range to the `>` of a `->` arrow. Lost with it:
                //  - the rest of this expression and closure;
                //  - the remainder of `loadImages()`, including how its return value is computed;
                //  - presumably the signature `private func createImageFolderIfMissing() -> Bool`,
                //    whose body follows.
                // Restore this span from version control; the code below does not compile as is.
                let missingImagesForCap: [(cap: Int, image: Int)] = (0.. Bool {
        // Body of (presumably) `createImageFolderIfMissing()`: creates `imageDirectory` (with
        // intermediate directories) if it does not exist. It returns false after logging on failure.
        // Only the top-level directory is created, not the per-cap subdirectories.
        guard !FileManager.default.fileExists(atPath: imageDirectory.path) else {
            return true
        }
        do {
            try FileManager.default.createDirectory(at: imageDirectory, withIntermediateDirectories: true)
            return true
        } catch {
            print("[ERROR] Failed to create image directory: \(error)")
            return false
        }
    }

    /// Fetches `caps.json` and returns a map from cap id to image count.
    ///
    /// Returns an empty dictionary if the request or decoding fails. For duplicate ids the
    /// last entry wins.
    private func getImageCounts() async -> [Int : Int] {
        guard let data: Data = await get(server.appendingPathComponent("caps.json")) else {
            return [:]
        }
        do {
            return try JSONDecoder().decode([Cap].self, from: data)
                .reduce(into: [:]) { $0[$1.id] = $1.count }
        } catch {
            print("[ERROR] Failed to decode cap database: \(error)")
            return [:]
        }
    }

    /// Builds `<base>/CCCC/CCCC-II.jpg`: the cap id zero-padded to at least 4 digits, the image
    /// index to at least 2.
    ///
    /// NOTE(review): `%d` is used with `Int`; Apple's String Format Specifiers guide recommends
    /// `%ld` for `Int`/`NSInteger` on 64-bit platforms.
    private func imageUrl(base: URL, cap: Int, image: Int) -> URL {
        base.appendingPathComponent(String(format: "%04d/%04d-%02d.jpg", cap, cap, image))
    }

    /// Downloads image `image` of cap `cap` from `<server>/images/...` into `imageDirectory`,
    /// replacing any existing local copy.
    ///
    /// - Returns: `true` on success. Transport errors, non-200 responses and file errors are
    ///   logged and return `false`.
    private func loadImage(cap: Int, image: Int) async -> Bool {
        let url = imageUrl(base: server.appendingPathComponent("images"), cap: cap, image: image)
        let tempFile: URL, response: URLResponse
        do {
            (tempFile, response) = try await URLSession.shared.download(from: url)
        } catch {
            print("[ERROR] Failed to load image \(image) of cap \(cap): \(error)")
            return false
        }
        // NOTE(review): force cast traps if the response is not an HTTPURLResponse (e.g. a
        // non-HTTP `serverPath`); `as?` with a logged failure would be safer.
        let responseCode = (response as! HTTPURLResponse).statusCode
        guard responseCode == 200 else {
            print("[ERROR] Failed to load image \(image) of cap \(cap): Response \(responseCode)")
            return false
        }
        do {
            let localUrl = imageUrl(base: imageDirectory, cap: cap, image: image)
            // NOTE(review): in the visible code only `imageDirectory` itself is created. If the
            // per-cap `CCCC/` subdirectory does not exist yet, `moveItem` fails. Consider calling
            // `createDirectory(at: localUrl.deletingLastPathComponent(), withIntermediateDirectories: true)`
            // before the move.
            if FileManager.default.fileExists(atPath: localUrl.path) {
                try FileManager.default.removeItem(at: localUrl)
            }
            try FileManager.default.moveItem(at: tempFile, to: localUrl)
            return true
        } catch {
            print("[ERROR] Failed to save image \(image) of cap \(cap): \(error)")
            return false
        }
    }

    /// Downloads every `(cap, image)` pair in `list` concurrently.
    ///
    /// - Returns: `true` if `list` is empty or every download succeeded. Otherwise it logs how
    ///   many succeeded and returns `false`.
    private func loadImages(_ list: [(cap: Int, image: Int)]) async -> Bool {
        guard !list.isEmpty else {
            return true
        }
        var loadedImages = 0
        // One child task per image is added up front; the group does no throttling of its own.
        await withTaskGroup(of: Bool.self) { group in
            for (cap, image) in list {
                group.addTask {
                    await self.loadImage(cap: cap, image: image)
                }
            }
            for await loaded in group {
                if loaded {
                    loadedImages += 1
                }
            }
        }
        if loadedImages != list.count {
            print("[ERROR] Only \(loadedImages) of \(list.count) images loaded")
            return false
        }
        return true
    }

    /// Fetches `changes.txt` and parses one `(cap, image)` pair per non-empty line.
    ///
    /// Each line must have exactly three `:`-separated fields. The 2nd and 3rd are the cap id and
    /// image index. The 1st is currently ignored; its date parsing is commented out. Lines with
    /// the wrong field count are skipped silently; unparsable ids are skipped with a warning.
    /// Returns an empty list if the request fails.
    ///
    /// NOTE(review): lines are split on "\n" only; with CRLF line endings the trailing "\r" would
    /// make `Int(parts[2])` fail for every line.
    func getChangedImageList() async -> [(cap: Int, image: Int)] {
        guard let string: String = await get(server.appendingPathComponent("changes.txt")) else {
            print("[ERROR] Failed to get list of changed images")
            return []
        }
        return string
            .components(separatedBy: "\n")
            .filter { $0 != "" }
            .compactMap {
                let parts = $0.components(separatedBy: ":")
                guard parts.count == 3 else {
                    return nil
                }
                // Disabled: parsing the first field as a date with `df`.
                /*
                guard let date = df.date(from: parts[0]) else {
                    print("[WARN] Invalid date \(parts[0]) in change list")
                    return nil
                }
                */
                guard let cap = Int(parts[1]) else {
                    print("[WARN] Invalid cap id \(parts[1]) in change list")
                    return nil
                }
                guard let image = Int(parts[2]) else {
                    print("[WARN] Invalid image id \(parts[2]) in change list")
                    return nil
                }
                return (cap, image)
            }
    }

    // MARK: Step 3: Compute version

    /// Fetches the server's current classifier version from `<server>/version` as an integer.
    ///
    /// NOTE(review): `Int(string)` rejects surrounding whitespace, so a trailing newline in the
    /// response makes this fail. Trimming `.whitespacesAndNewlines` first would be more robust.
    func getClassifierVersion() async -> Int? {
        guard let string: String = await get(server.appendingPathComponent("version")) else {
            print("[ERROR] Failed to get classifier version")
            return nil
        }
        guard let version = Int(string) else {
            print("[ERROR] Invalid classifier version \(string)")
            return nil
        }
        return version
    }

    // MARK: Step 4: Train classifier

    /// Trains an image classifier on `imageDirectory` and writes it to `classifierUrl`.
    ///
    /// With `.labeledDirectories`, CreateML uses each subdirectory name as a class label.
    /// Augmentation is disabled and `maxIterations` comes from the configuration. Training is a
    /// synchronous call, so it blocks the calling task until it finishes.
    ///
    /// - Returns: `true` if training and saving both succeeded; errors are logged.
    func trainModel() -> Bool {
        var params = MLImageClassifier.ModelParameters(augmentation: [])
        params.maxIterations = configuration.trainingIterations
        let model: MLImageClassifier
        do {
            model = try MLImageClassifier(
                trainingData: .labeledDirectories(at: imageDirectory),
                parameters: params)
        } catch {
            print("[ERROR] Failed to create classifier: \(error)")
            return false
        }
        print("[INFO] Saving classifier...")
        do {
            try model.write(to: classifierUrl)
            return true
        } catch {
            print("[ERROR] Failed to save model to file: \(error)")
            return false
        }
    }

    // MARK: Step 5: Upload classifier

    /// Reads the saved model file and POSTs it to `<server>/classifier/<version>`.
    ///
    /// - Returns: `true` if the file was read and the server answered 200.
    func upload(version: Int) async -> Bool {
        print("[INFO] Uploading classifier...")
        let modelData: Data
        do {
            modelData = try Data(contentsOf: classifierUrl)
        } catch {
            print("[ERROR] Failed to read classifier data: \(error)")
            return false
        }
        return await post(
            url: server.appendingPathComponent("classifier/\(version)"),
            body: modelData)
    }

    // MARK: Step 6: Update classes

    /// POSTs the trained class ids, newline-separated as UTF-8, to `<server>/classes/<date>`.
    /// The date is `lastUpdate` formatted with `df`.
    ///
    /// NOTE(review): `post(url:body:)` has no `dateKey:` parameter, so this call does not compile
    /// as written. Either drop the argument (the date is already in the URL) or add the parameter.
    func upload(classes: [Int], lastUpdate: Date) async -> Bool {
        print("[INFO] Uploading trained classes...")
        let dateString = df.string(from: lastUpdate)
        return await post(
            url: server.appendingPathComponent("classes/\(dateString)"),
            body: classes.map(String.init).joined(separator: "\n").data(using: .utf8)!,
            dateKey: lastUpdate)
    }

    // MARK: Requests

    /// Sends `body` as a POST to `url`, with the configured token in the `key` header.
    ///
    /// - Returns: `true` only if the server answered with status 200.
    private func post(url: URL, body: Data) async -> Bool {
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.httpBody = body
        request.addValue(configuration.authenticationToken, forHTTPHeaderField: "key")
        return await perform(request) != nil
    }

    /// Performs `request` and returns the response body if the status code is 200.
    ///
    /// Transport errors and any other status code are logged and yield `nil`.
    private func perform(_ request: URLRequest) async -> Data? {
        let data: Data
        let response: URLResponse
        do {
            (data, response) = try await URLSession.shared.data(for: request)
        } catch {
            print("[ERROR] Request failed: \(error)")
            return nil
        }
        // NOTE(review): force cast traps on a non-HTTP response; `as?` would fail gracefully.
        let code = (response as! HTTPURLResponse).statusCode
        guard code == 200 else {
            print("[ERROR] Invalid response \(code)")
            return nil
        }
        return data
    }

    /// GETs `url` and returns the body on status 200, otherwise `nil`.
    private func get(_ url: URL) async -> Data? {
        await perform(URLRequest(url: url))
    }

    /// GETs `url` and decodes the body as UTF-8; returns `nil` on request or decoding failure.
    ///
    /// Overloads `get(_:)` by return type, so call sites select it with a `String` annotation.
    private func get(_ url: URL) async -> String? {
        guard let data: Data = await get(url) else {
            return nil
        }
        guard let string = String(data: data, encoding: .utf8) else {
            print("[ERROR] Invalid string response \(data)")
            return nil
        }
        return string
    }
}

// Script entry point.
// NOTE(review): `init?()` is failable, so `ClassifierCreator()` is optional and `.run()` cannot be
// called on it directly; this needs `await ClassifierCreator()?.run()`.
await ClassifierCreator().run()