diff --git a/Sources/ClassifierCreator.swift b/Sources/ClassifierCreator.swift index 50a61be..b13f4ab 100644 --- a/Sources/ClassifierCreator.swift +++ b/Sources/ClassifierCreator.swift @@ -293,7 +293,7 @@ final class ClassifierCreator { parameters: params, sessionParameters: sessionParameters) } catch { - throw TrainingError.failedToCreateClassifier("\(error)") + throw TrainingError.failedToCreateClassifier(error) } job.progress @@ -301,7 +301,6 @@ final class ClassifierCreator { .sink { completed in Swift.print(String(format: " %.1f %% completed", completed * 100), terminator: "\r") fflush(stdout) - //guard let progress = MLProgress(progress: job.progress) else { // return //} @@ -314,7 +313,12 @@ final class ClassifierCreator { return try await withCheckedThrowingContinuation { continuation in // Register a sink to receive the resulting model. job.result.sink { result in - continuation.resume(throwing: TrainingError.failedToCreateClassifier("\(result)")) + switch result { + case .finished: + break // Continuation already called with model + case .failure(let error): + continuation.resume(throwing: TrainingError.failedToCreateClassifier(error)) + } } receiveValue: { [weak self] model in // Use model self?.print(info: "Created model") @@ -334,12 +338,11 @@ final class ClassifierCreator { trainingData: .labeledDirectories(at: imageDirectory), parameters: params) } catch { - throw TrainingError.failedToCreateClassifier("\(error)") + throw TrainingError.failedToCreateClassifier(error) } } private func save(model: MLImageClassifier) throws { - print(info: "Saving classifier...") do { try model.write(to: classifierUrl) diff --git a/Sources/TrainingError.swift b/Sources/TrainingError.swift index e7a5388..0cf01fe 100644 --- a/Sources/TrainingError.swift +++ b/Sources/TrainingError.swift @@ -31,7 +31,7 @@ enum TrainingError: Error { case failedToGetClassifierVersion(Error) case invalidClassifierVersion(String) - case failedToCreateClassifier(String) + case failedToCreateClassifier(Error) case failedToWriteClassifier(Error) case failedToReadClassifierData(Error) @@ -95,8 +95,8 @@ extension TrainingError: CustomStringConvertible { return "Failed to get classifier version: \(error)" case .invalidClassifierVersion(let string): return "Invalid classifier version \(string)" - case .failedToCreateClassifier(let result): - return "Failed to create classifier: \(result)" + case .failedToCreateClassifier(let error): + return "Failed to create classifier: \(error)" case .failedToWriteClassifier(let error): return "Failed to save model to file: \(error)" case .failedToReadClassifierData(let error):