diff --git a/Sources/ClassifierCreator.swift b/Sources/ClassifierCreator.swift index 5ee72a4..b5f7c28 100644 --- a/Sources/ClassifierCreator.swift +++ b/Sources/ClassifierCreator.swift @@ -1,5 +1,6 @@ import Foundation import CreateML +import Combine final class ClassifierCreator { @@ -11,6 +12,8 @@ final class ClassifierCreator { let thumbnailDirectory: URL + let sessionDirectory: URL + let classifierUrl: URL let df = DateFormatter() @@ -23,6 +26,7 @@ final class ClassifierCreator { self.server = try configuration.serverUrl() let contentDirectory = URL(fileURLWithPath: configuration.contentFolder) self.imageDirectory = contentDirectory.appendingPathComponent("images") + self.sessionDirectory = contentDirectory.appendingPathComponent("session") self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel") self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails") df.dateFormat = "yy-MM-dd-HH-mm-ss" @@ -55,8 +59,8 @@ final class ClassifierCreator { print("[INFO] Classes: \(classes.count)") print("[INFO] Iterations: \(configuration.trainingIterations)") - try trainModel() - try await upload(version: newVersion) + try await trainAndSaveModel() + try await uploadModel(version: newVersion) try await upload(classes: classes, lastUpdate: imagesSnapshotDate) await createThumbnails(changed: changedMainImages) print("[INFO] Done") @@ -254,18 +258,80 @@ final class ClassifierCreator { // MARK: Step 4: Train classifier - func trainModel() throws { - var params = MLImageClassifier.ModelParameters(augmentation: []) - params.maxIterations = configuration.trainingIterations - let model: MLImageClassifier + func trainAndSaveModel() async throws { + let model = try await trainModelAsync() + //let model = try trainModelSync() + try save(model: model) + } + + func trainModelAsync() async throws -> MLImageClassifier { + let params = MLImageClassifier.ModelParameters( + maxIterations: configuration.trainingIterations, + augmentation: []) + + let sessionParameters = MLTrainingSessionParameters( + sessionDirectory: sessionDirectory, + reportInterval: 10, + checkpointInterval: 100, + iterations: 1000) + + var subscriptions = [AnyCancellable]() + + let job: MLJob do { - model = try MLImageClassifier( + job = try MLImageClassifier.train( + trainingData: .labeledDirectories(at: imageDirectory), + parameters: params, + sessionParameters: sessionParameters) + } catch { + throw TrainingError.failedToCreateClassifier + } + + job.progress + .publisher(for: \.fractionCompleted) + .sink { completed in + print(String(format: " %.1f %% completed", completed * 100), terminator: "\r") + fflush(stdout) + + //guard let progress = MLProgress(progress: job.progress) else { + // return + //} + //if let styleLoss = progress.metrics[.styleLoss] { _ = styleLoss } + //if let contentLoss = progress.metrics[.contentLoss] { _ = contentLoss } + + } + .store(in: &subscriptions) + + return try await withCheckedThrowingContinuation { continuation in + // Register a sink to receive the resulting model. + job.result.sink { result in + print("[ERROR] \(result)") + continuation.resume(throwing: TrainingError.failedToCreateClassifier) + } receiveValue: { model in + // Use model + print("[INFO] Created model") + continuation.resume(returning: model) + } + .store(in: &subscriptions) + } + } + + func trainModelSync() throws -> MLImageClassifier { + let params = MLImageClassifier.ModelParameters( + maxIterations: configuration.trainingIterations, + augmentation: []) + + do { + return try MLImageClassifier( trainingData: .labeledDirectories(at: imageDirectory), parameters: params) } catch { print("[ERROR] Failed to create classifier: \(error)") - throw TrainingError.failedToCreateClassifier(error) + throw TrainingError.failedToCreateClassifier } + } + + private func save(model: MLImageClassifier) throws { print("[INFO] Saving classifier...") do { @@ -278,7 +344,7 @@ final class ClassifierCreator { // MARK: Step 5: Upload classifier - func upload(version: Int) async throws { + func uploadModel(version: Int) async throws { print("[INFO] Uploading classifier...") let modelData: Data do { diff --git a/Sources/TrainingError.swift b/Sources/TrainingError.swift index e73b03b..ea27c24 100644 --- a/Sources/TrainingError.swift +++ b/Sources/TrainingError.swift @@ -17,7 +17,7 @@ enum TrainingError: Error { case failedToGetClassifierVersion case invalidClassifierVersion(String) - case failedToCreateClassifier(Error) + case failedToCreateClassifier case failedToWriteClassifier(Error) case failedToGetChangedImagesList