import Foundation
import CreateML
import Combine

/// Downloads cap images from the server, trains an image classifier on them with CreateML,
/// and uploads the classifier, its class list and generated thumbnails back to the server.
final class ClassifierCreator {

    private let scenePrintRevision: Int? = 2

    private let sessionProgressIterations = 10

    private let sessionCheckpointIterations = 100

    private let sessionIterations = 1000

    /// Edge length (in pixels) passed to `magick -resize` when creating thumbnails.
    private let thumbnailSize = 100

    let server: URL

    let configuration: Configuration

    /// The number of images required to include a cap in training
    let minimumImagesPerCap: Int

    let imageDirectory: URL

    let thumbnailDirectory: URL

    let sessionDirectory: URL

    let classifierUrl: URL

    let urlSession = URLSession(configuration: .ephemeral)

    /// Formatter for the fixed `yy-MM-dd-HH-mm-ss` timestamps used when parsing `changes.txt`
    /// and when building the `classes/<date>` upload path. Configured in `init`.
    let df = DateFormatter()

    /// Raised (wrapped in `TrainingError.failedToCreateClassifier`) when the CreateML result
    /// publisher completes without ever delivering a model.
    private struct TrainingJobProducedNoModel: Error {}

    /// Prints `info` to stdout with an `[INFO]` prefix.
    private func print(info: String) {
        Swift.print("[INFO] " + info)
    }

    // MARK: Step 1: Load configuration

    /// Derives all working locations from `configuration.contentFolder`.
    /// - Parameters:
    ///   - configuration: Supplies the server URL, content folder, image threshold and training settings.
    ///   - resume: When `false`, existing training session data is deleted so training starts fresh.
    /// - Throws: Errors from `configuration.serverUrl()`, or `TrainingError.failedToRemoveSessionFolder`.
    init(configuration: Configuration, resume: Bool) throws {
        self.configuration = configuration
        self.server = try configuration.serverUrl()
        self.minimumImagesPerCap = configuration.minimumImagesPerCap
        let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
        self.imageDirectory = contentDirectory.appendingPathComponent("images")
        self.sessionDirectory = contentDirectory.appendingPathComponent("session")
        self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
        self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
        // Fixed-format dates must use the POSIX locale; otherwise user preferences
        // (e.g. a forced 12-hour clock) can change how `HH` is handled and break parsing/formatting.
        df.locale = Locale(identifier: "en_US_POSIX")
        df.dateFormat = "yy-MM-dd-HH-mm-ss"
        if !resume {
            try removeSessionData()
        }
    }

    /// Deletes `sessionDirectory` if it exists.
    /// - Throws: `TrainingError.failedToRemoveSessionFolder` if the folder cannot be removed.
    private func removeSessionData() throws {
        guard FileManager.default.fileExists(atPath: sessionDirectory.path) else {
            return
        }
        do {
            try FileManager.default.removeItem(at: sessionDirectory)
            print(info: "Removed training session data")
        } catch {
            throw TrainingError.failedToRemoveSessionFolder(error)
        }
    }

    // MARK: Main function

    /// Runs the whole pipeline: sync images, train (unless skipped), upload model and classes,
    /// create thumbnails, and clean up the session folder.
    /// - Parameter skipTraining: When `true`, the existing file at `classifierUrl` is uploaded instead of training a new one.
    /// - Note: Exits early (after refreshing thumbnails) when no images changed since the last run.
    func run(skipTraining: Bool) async throws {
        // Captured before loading so the uploaded timestamp never post-dates the image snapshot.
        let imagesSnapshotDate = Date()
        let (classes, changedImageCount, changedMainImages) = try await loadImages()
        guard !classes.isEmpty else {
            print(info: "No image classes found, exiting...")
            return
        }
        guard changedImageCount > 0 else {
            print(info: "No changed images, so no new classifier trained")
            try await createThumbnails(changed: changedMainImages)
            print(info: "Done")
            return
        }
        let classifierVersion = try await getClassifierVersion()
        let newVersion = classifierVersion + 1
        print(info: "Server: \(server.path)")
        print(info: "Image directory: \(imageDirectory.absoluteURL.path)")
        print(info: "Model path: \(classifierUrl.path)")
        print(info: "Version: \(newVersion)")
        print(info: "Classes: \(classes.count)")
        print(info: "Iterations: \(configuration.trainingIterations)")
        if !skipTraining {
            try await trainAndSaveModel()
        }
        try await uploadModel(version: newVersion)
        try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
        try await createThumbnails(changed: changedMainImages)
        try removeSessionData()
        print(info: "Done")
    }

    // MARK: Step 2: Load changed images

    /// Synchronises `imageDirectory` with the server for every cap that has enough images,
    /// and removes folders of caps that are no longer trained.
    /// - Returns: The sorted ids of the trained caps (`classes`), the number of missing plus changed
    ///   images, and the ids of caps whose main image (`image == 0`) changed.
    func loadImages() async throws -> (classes: [Int], changedImageCount: Int, changedMainImages: [Int]) {
        do {
            try createFolderIfMissing(imageDirectory)
        } catch {
            throw TrainingError.mainImageFolderNotCreated(error)
        }
        let imageCounts = try await getImageCounts()
            .filter { id, count in
                // Exclude caps with too few images so that every trained class has enough samples
                if count < self.minimumImagesPerCap {
                    print(info: "Excluding cap \(id) from training (\(count) images)")
                    return false
                }
                return true
            }
        let missingImageList: [CapImage] = imageCounts
            .filter { $0.value >= self.minimumImagesPerCap }
            .sorted { $0.key < $1.key }
            .reduce(into: []) { list, pair in
                // NOTE(review): the source text on the next line was damaged before review. Everything
                // between "(0..<" and the ">" of "imagesAlreadyLoad > 0" appears to have been stripped
                // (it looks like `<…>` was removed as markup). It is kept verbatim; restore it from VCS.
                let missingImagesForCap: [CapImage] = (0.. 0 ? " (\(imagesAlreadyLoad) already loaded)" : ""
        if filteredChangeList.isEmpty {
            print(info: "No changed images to load" + suffix)
        } else {
            print(info: "Loading \(filteredChangeList.count) changed images" + suffix)
            try await loadImagesInBatches(filteredChangeList)
        }
        let changedMainImages = changedImageList.filter { $0.image == 0 }.map { $0.cap }
        let classes = imageCounts.keys.sorted()
        // Delete any image folders not present as caps
        try deleteUnnecessaryImageFolders(caps: classes)
        return (classes, missingImageList.count + changedImageList.count, changedMainImages)
    }

    /// Fetches `caps.json` from the server and maps each cap id to its image count.
    /// - Throws: `TrainingError.failedToGetCapDatabase` or `TrainingError.failedToDecodeCapDatabase`.
    private func getImageCounts() async throws -> [Int : Int] {
        let data: Data
        do {
            data = try await get(server.appendingPathComponent("caps.json"))
        } catch {
            throw TrainingError.failedToGetCapDatabase(error)
        }
        do {
            return try JSONDecoder()
                .decode([Cap].self, from: data)
                .reduce(into: [:]) { $0[$1.id] = $1.count }
        } catch {
            throw TrainingError.failedToDecodeCapDatabase(error)
        }
    }

    /// Removes every entry in `imageDirectory` whose name is not the zero-padded id (`%04d`) of one of `caps`.
    /// - Throws: `TrainingError.failedToGetListOfImageFolders` or `TrainingError.failedToRemoveImageFolder`.
    private func deleteUnnecessaryImageFolders(caps: [Int]) throws {
        let validNames = Set(caps.map { String(format: "%04d", $0) })
        let folders: [String]
        do {
            folders = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
        } catch {
            throw TrainingError.failedToGetListOfImageFolders(error)
        }
        for folder in folders where !validNames.contains(folder) {
            // Not a valid cap folder
            let url = imageDirectory.appendingPathComponent(folder)
            do {
                try FileManager.default.removeItem(at: url)
                print(info: "Removed item in image folder: '\(folder)'")
            } catch {
                throw TrainingError.failedToRemoveImageFolder(folder, error)
            }
        }
    }

    /// Returns `base/<cap as %04d>`.
    private func capFolderUrl(base: URL, cap: Int) -> URL {
        base.appendingPathComponent(String(format: "%04d", cap))
    }

    /// Returns `base/<cap as %04d>/<cap as %04d>-<image as %02d>.jpg`.
    private func imageUrl(base: URL, image: CapImage) -> URL {
        capFolderUrl(base: base, cap: image.cap)
            .appendingPathComponent(String(format: "%04d-%02d.jpg", image.cap, image.image))
    }

    /// Downloads one image from `<server>/images/...` and stores it at the matching path in
    /// `imageDirectory`, replacing any existing file.
    /// - Throws: `failedToCreateImageFolder`, `failedToLoadImage`, `invalidImageRequestResponse`
    ///   (status `-1` if the response was not HTTP), or `failedToSaveImage`.
    private func load(image: CapImage) async throws {
        do {
            try createFolderIfMissing(capFolderUrl(base: imageDirectory, cap: image.cap))
        } catch {
            throw TrainingError.failedToCreateImageFolder(image.cap, error)
        }
        let url = imageUrl(base: server.appendingPathComponent("images"), image: image)
        let tempFile: URL, response: URLResponse
        do {
            (tempFile, response) = try await urlSession.download(from: url)
        } catch {
            throw TrainingError.failedToLoadImage(image, error)
        }
        // The caller owns the downloaded file: remove it on every error path.
        // After a successful move this is a harmless no-op.
        defer { try? FileManager.default.removeItem(at: tempFile) }
        let responseCode = (response as? HTTPURLResponse)?.statusCode ?? -1
        guard responseCode == 200 else {
            throw TrainingError.invalidImageRequestResponse(image, responseCode)
        }
        do {
            let localUrl = imageUrl(base: imageDirectory, image: image)
            if FileManager.default.fileExists(atPath: localUrl.path) {
                try FileManager.default.removeItem(at: localUrl)
            }
            try FileManager.default.moveItem(at: tempFile, to: localUrl)
        } catch {
            throw TrainingError.failedToSaveImage(image, error)
        }
    }

    /// Downloads `list` in consecutive slices of at most `batchSize` images.
    /// NOTE(review): the rest of this method and the header of the following change-list parser were
    /// lost from the source (see the note inside); confirm the batch semantics against version control.
    private func loadImagesInBatches(_ list: [CapImage], batchSize: Int = 100) async throws {
        guard !list.isEmpty else {
            return
        }
        print(info: "Loading \(list.count) images...")
        var startIndex = list.startIndex
        while startIndex < list.endIndex {
            let endIndex = min(startIndex + batchSize, list.count)
            // NOTE(review): the source text on the next line was damaged before review. Everything between
            // "startIndex..<" and the "->" of the next function's signature appears to have been stripped
            // (it looks like `<…>` was removed as markup). It is kept verbatim; restore it from VCS.
            let batch = Array(list[startIndex.. [CapImage] {
        let string: String
        do {
            string = try await get(server.appendingPathComponent("changes.txt"))
        } catch {
            throw TrainingError.failedToGetChangedImagesList(error)
        }
        // Each non-empty line must be `<date>:<cap>:<image>`, with the date in `df`'s format.
        return try string
            .components(separatedBy: "\n")
            .map { $0.trimmingCharacters(in: .whitespaces) }
            .filter { $0 != "" }
            .compactMap {
                let parts = $0.components(separatedBy: ":")
                guard parts.count == 3,
                      let _ = df.date(from: parts[0]),
                      let cap = Int(parts[1]),
                      let image = Int(parts[2]) else {
                    throw TrainingError.invalidEntryInChangeList($0)
                }
                return CapImage(cap: cap, image: image)
            }
    }

    // MARK: Step 3: Compute version

    /// Reads the current classifier version from `<server>/version`.
    /// Surrounding whitespace and newlines in the response are ignored.
    /// - Throws: `TrainingError.failedToGetClassifierVersion` or `TrainingError.invalidClassifierVersion`
    ///   (carrying the raw response).
    func getClassifierVersion() async throws -> Int {
        let string: String
        do {
            string = try await get(server.appendingPathComponent("version"))
        } catch {
            throw TrainingError.failedToGetClassifierVersion(error)
        }
        guard let version = Int(string.trimmingCharacters(in: .whitespacesAndNewlines)) else {
            throw TrainingError.invalidClassifierVersion(string)
        }
        return version
    }

    // MARK: Step 4: Train classifier

    /// Trains a classifier with `trainModelAsync()` and writes it to `classifierUrl`.
    /// (`trainModelSync()` is a blocking alternative without session checkpoints.)
    func trainAndSaveModel() async throws {
        let model = try await trainModelAsync()
        try save(model: model)
    }

    /// Trains an `MLImageClassifier` on the labelled sub-folders of `imageDirectory`, using a training
    /// session stored in `sessionDirectory` and printing the completed fraction to stdout.
    /// - Throws: `TrainingError.failedToCreateClassifier` if the job cannot be created, fails, or
    ///   finishes without producing a model.
    func trainModelAsync() async throws -> MLImageClassifier {
        let params = MLImageClassifier.ModelParameters(
            maxIterations: configuration.trainingIterations,
            augmentation: [],
            algorithm: .transferLearning(
                featureExtractor: .scenePrint(revision: scenePrintRevision),
                classifier: .logisticRegressor))
        let sessionParameters = MLTrainingSessionParameters(
            sessionDirectory: sessionDirectory,
            reportInterval: sessionProgressIterations,
            checkpointInterval: sessionCheckpointIterations,
            iterations: sessionIterations)
        // Keeps the Combine subscriptions alive until this function returns.
        var subscriptions = [AnyCancellable]()
        let job: MLJob
        do {
            job = try MLImageClassifier.train(
                trainingData: .labeledDirectories(at: imageDirectory),
                parameters: params,
                sessionParameters: sessionParameters)
        } catch {
            throw TrainingError.failedToCreateClassifier(error)
        }
        job.progress
            .publisher(for: \.fractionCompleted)
            .sink { completed in
                Swift.print(String(format: " %.1f %% completed", completed * 100), terminator: "\r")
                fflush(stdout)
            }
            .store(in: &subscriptions)

        return try await withCheckedThrowingContinuation { continuation in
            // The continuation must be resumed exactly once. This flag also lets the completion
            // handler detect a publisher that finished without ever delivering a model.
            var didResume = false
            // Register a sink to receive the resulting model.
            job.result.sink { result in
                guard !didResume else { return }
                didResume = true
                switch result {
                case .finished:
                    // Only reached without a prior value; otherwise `receiveValue` already resumed.
                    continuation.resume(throwing: TrainingError.failedToCreateClassifier(TrainingJobProducedNoModel()))
                case .failure(let error):
                    continuation.resume(throwing: TrainingError.failedToCreateClassifier(error))
                }
            } receiveValue: { [weak self] model in
                guard !didResume else { return }
                didResume = true
                self?.print(info: "Created model")
                continuation.resume(returning: model)
            }
            .store(in: &subscriptions)
        }
    }

    /// Trains a classifier synchronously on `imageDirectory` (default algorithm, no session).
    /// - Throws: `TrainingError.failedToCreateClassifier`.
    func trainModelSync() throws -> MLImageClassifier {
        let params = MLImageClassifier.ModelParameters(
            maxIterations: configuration.trainingIterations,
            augmentation: [])
        do {
            return try MLImageClassifier(
                trainingData: .labeledDirectories(at: imageDirectory),
                parameters: params)
        } catch {
            throw TrainingError.failedToCreateClassifier(error)
        }
    }

    /// Writes `model` to `classifierUrl`.
    /// - Throws: `TrainingError.failedToWriteClassifier`.
    private func save(model: MLImageClassifier) throws {
        print(info: "Saving classifier...")
        do {
            try model.write(to: classifierUrl)
        } catch {
            throw TrainingError.failedToWriteClassifier(error)
        }
    }

    // MARK: Step 5: Upload classifier

    /// POSTs the file at `classifierUrl` to `<server>/classifier/<version>`.
    /// - Throws: `TrainingError.failedToReadClassifierData` or `TrainingError.failedToUploadClassifier`.
    func uploadModel(version: Int) async throws {
        print(info: "Uploading classifier...")
        let modelData: Data
        do {
            modelData = try Data(contentsOf: classifierUrl)
        } catch {
            throw TrainingError.failedToReadClassifierData(error)
        }
        let url = server.appendingPathComponent("classifier/\(version)")
        do {
            try await post(url: url, body: modelData)
        } catch {
            throw TrainingError.failedToUploadClassifier(error)
        }
    }

    // MARK: Step 6: Update classes

    /// POSTs the comma-separated class ids to `<server>/classes/<lastUpdate formatted by df>`.
    /// - Throws: `TrainingError.failedToUploadClassifierClasses`.
    func upload(classes: [Int], lastUpdate: Date) async throws {
        print(info: "Uploading trained classes...")
        let dateString = df.string(from: lastUpdate)
        let url = server.appendingPathComponent("classes/\(dateString)")
        let body = Data(classes.map(String.init).joined(separator: ",").utf8)
        do {
            try await post(url: url, body: body)
        } catch {
            throw TrainingError.failedToUploadClassifierClasses(error)
        }
    }

    // MARK: Step 7: Create thumbnails

    /// Creates and uploads thumbnails for caps reported as missing by the server plus the caps in `changed`.
    /// - Parameter changed: Ids of caps whose main image changed.
    func createThumbnails(changed: [Int]) async throws {
        try ensureMagickAvailability()
        do {
            try createFolderIfMissing(thumbnailDirectory)
        } catch {
            throw TrainingError.failedToCreateThumbnailFolder(error)
        }
        let capIdsOfMissingThumbnails = try await getMissingThumbnailIds()
        let all = Set(capIdsOfMissingThumbnails).union(changed)
        print(info: "Creating \(all.count) thumbnails...")
        // Sorted so that progress and failures are reported in a reproducible order.
        for cap in all.sorted() {
            try await createThumbnail(for: cap)
        }
    }

    /// Verifies that ImageMagick's `magick` command is available and logs its version.
    /// - Throws: `TrainingError.magickDependencyCheckFailed` if the shell could not be run,
    ///   `TrainingError.magickDependencyNotFound` if `magick --version` fails or prints no version.
    func ensureMagickAvailability() throws {
        let result: (code: Int32, output: String)
        do {
            result = try safeShell("magick --version")
        } catch {
            throw TrainingError.magickDependencyCheckFailed(error)
        }
        // Thrown outside the `do` so callers see this specific error instead of a wrapped one.
        guard result.code == 0,
              let version = result.output
                .components(separatedBy: "ImageMagick ")
                .dropFirst()
                .first?
                .components(separatedBy: " ")
                .first else {
            throw TrainingError.magickDependencyNotFound
        }
        print(info: "Using magick \(version)")
    }

    /// Fetches the comma-separated ids from `<server>/thumbnails/missing`.
    /// Entries that are not integers (after trimming whitespace) are skipped.
    /// - Throws: `TrainingError.failedToGetMissingThumbnailIds`.
    private func getMissingThumbnailIds() async throws -> [Int] {
        let string: String
        do {
            string = try await get(server.appendingPathComponent("thumbnails/missing"))
        } catch {
            throw TrainingError.failedToGetMissingThumbnailIds(error)
        }
        return string
            .components(separatedBy: ",")
            .map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }
            .compactMap(Int.init)
    }

    /// Renders the cap's main image into `thumbnailDirectory/<cap as %04d>.jpg` with ImageMagick
    /// and uploads it to `<server>/thumbnails/<cap>`.
    /// - Throws: `missingMainImage`, `failedToCreateThumbnail` (with the tool output or error text),
    ///   `failedToReadCreatedThumbnail`, or `failedToUploadCreatedThumbnail`.
    private func createThumbnail(for cap: Int) async throws {
        let mainImage = CapImage(cap: cap, image: 0)
        let inputUrl = imageUrl(base: imageDirectory, image: mainImage)
        guard FileManager.default.fileExists(atPath: inputUrl.path) else {
            throw TrainingError.missingMainImage(cap)
        }
        let thumbnailUrl = thumbnailDirectory.appendingPathComponent(String(format: "%04d.jpg", cap))
        // Paths are quoted so a content folder with spaces or shell metacharacters still works.
        let command = "magick convert \(shellQuoted(inputUrl.path)) -quality 70% -resize \(thumbnailSize)x\(thumbnailSize) \(shellQuoted(thumbnailUrl.path))"
        let result: (code: Int32, output: String)
        do {
            result = try safeShell(command)
        } catch {
            throw TrainingError.failedToCreateThumbnail(cap, "\(error)")
        }
        guard result.code == 0 else {
            throw TrainingError.failedToCreateThumbnail(cap, result.output)
        }
        let data: Data
        do {
            data = try Data(contentsOf: thumbnailUrl)
        } catch {
            throw TrainingError.failedToReadCreatedThumbnail(cap, error)
        }
        do {
            try await post(url: server.appendingPathComponent("thumbnails/\(cap)"), body: data)
        } catch {
            throw TrainingError.failedToUploadCreatedThumbnail(cap, error)
        }
    }

    // MARK: Helper

    /// Runs `command` through a login `zsh` (`-cl`) and captures combined stdout and stderr.
    /// - Returns: The exit status and the output, decoded as UTF-8 (invalid bytes become U+FFFD).
    /// - Throws: Any error from `Process.run()`.
    @discardableResult
    private func safeShell(_ command: String) throws -> (code: Int32, output: String) {
        let task = Process()
        let pipe = Pipe()
        task.standardOutput = pipe
        task.standardError = pipe
        task.arguments = ["-cl", command]
        task.executableURL = URL(fileURLWithPath: "/bin/zsh")
        task.standardInput = nil
        try task.run()
        // Drain the pipe before waiting. A child that writes more than the pipe buffer
        // would otherwise block on write while we block waiting for it to exit.
        let data = pipe.fileHandleForReading.readDataToEndOfFile()
        task.waitUntilExit()
        let output = String(decoding: data, as: UTF8.self)
        return (task.terminationStatus, output)
    }

    /// Wraps `argument` in single quotes for the shell, escaping embedded single quotes as `'\''`.
    private func shellQuoted(_ argument: String) -> String {
        "'" + argument.replacingOccurrences(of: "'", with: "'\\''") + "'"
    }

    /// Creates `folder` (and any intermediate directories) unless something already exists at its path.
    private func createFolderIfMissing(_ folder: URL) throws {
        guard !FileManager.default.fileExists(atPath: folder.path) else {
            return
        }
        try FileManager.default.createDirectory(at: folder, withIntermediateDirectories: true)
    }

    // MARK: Requests

    /// POSTs `body` to `url`, sending `configuration.authenticationToken` in the `key` header.
    /// - Throws: Transport errors, or `TrainingError.invalidResponse` for a non-200 status.
    private func post(url: URL, body: Data) async throws {
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.httpBody = body
        request.addValue(configuration.authenticationToken, forHTTPHeaderField: "key")
        _ = try await perform(request)
    }

    /// Performs `request` and returns the body of a 200 response.
    /// - Throws: Transport errors, or `TrainingError.invalidResponse` for any other status
    ///   (`-1` when the response is not an HTTP response).
    private func perform(_ request: URLRequest) async throws -> Data {
        let (data, response) = try await urlSession.data(for: request)
        let code = (response as? HTTPURLResponse)?.statusCode ?? -1
        guard code == 200 else {
            throw TrainingError.invalidResponse(request.url!, code)
        }
        return data
    }

    /// GETs `url` and returns the raw body of a 200 response.
    private func get(_ url: URL) async throws -> Data {
        try await perform(URLRequest(url: url))
    }

    /// GETs `url` and decodes the body as UTF-8.
    /// - Throws: `TrainingError.invalidGetResponseData` (with the byte count) if the body is not valid UTF-8.
    private func get(_ url: URL) async throws -> String {
        let data: Data = try await get(url)
        guard let string = String(data: data, encoding: .utf8) else {
            throw TrainingError.invalidGetResponseData(data.count)
        }
        return string
    }
}