Fix continuation error

This commit is contained in:
Christoph Hagen 2023-10-24 10:38:13 +02:00
parent 874688b43f
commit 2913878b3e
2 changed files with 11 additions and 8 deletions

View File

@ -293,7 +293,7 @@ final class ClassifierCreator {
parameters: params, parameters: params,
sessionParameters: sessionParameters) sessionParameters: sessionParameters)
} catch { } catch {
throw TrainingError.failedToCreateClassifier("\(error)") throw TrainingError.failedToCreateClassifier(error)
} }
job.progress job.progress
@ -301,7 +301,6 @@ final class ClassifierCreator {
.sink { completed in .sink { completed in
Swift.print(String(format: " %.1f %% completed", completed * 100), terminator: "\r") Swift.print(String(format: " %.1f %% completed", completed * 100), terminator: "\r")
fflush(stdout) fflush(stdout)
//guard let progress = MLProgress(progress: job.progress) else { //guard let progress = MLProgress(progress: job.progress) else {
// return // return
//} //}
@ -314,7 +313,12 @@ final class ClassifierCreator {
return try await withCheckedThrowingContinuation { continuation in return try await withCheckedThrowingContinuation { continuation in
// Register a sink to receive the resulting model. // Register a sink to receive the resulting model.
job.result.sink { result in job.result.sink { result in
continuation.resume(throwing: TrainingError.failedToCreateClassifier("\(result)")) switch result {
case .finished:
break // Continuation already called with model
case .failure(let error):
continuation.resume(throwing: TrainingError.failedToCreateClassifier(error))
}
} receiveValue: { [weak self] model in } receiveValue: { [weak self] model in
// Use model // Use model
self?.print(info: "Created model") self?.print(info: "Created model")
@ -334,12 +338,11 @@ final class ClassifierCreator {
trainingData: .labeledDirectories(at: imageDirectory), trainingData: .labeledDirectories(at: imageDirectory),
parameters: params) parameters: params)
} catch { } catch {
throw TrainingError.failedToCreateClassifier("\(error)") throw TrainingError.failedToCreateClassifier(error)
} }
} }
private func save(model: MLImageClassifier) throws { private func save(model: MLImageClassifier) throws {
print(info: "Saving classifier...") print(info: "Saving classifier...")
do { do {
try model.write(to: classifierUrl) try model.write(to: classifierUrl)

View File

@ -31,7 +31,7 @@ enum TrainingError: Error {
case failedToGetClassifierVersion(Error) case failedToGetClassifierVersion(Error)
case invalidClassifierVersion(String) case invalidClassifierVersion(String)
case failedToCreateClassifier(String) case failedToCreateClassifier(Error)
case failedToWriteClassifier(Error) case failedToWriteClassifier(Error)
case failedToReadClassifierData(Error) case failedToReadClassifierData(Error)
@ -95,8 +95,8 @@ extension TrainingError: CustomStringConvertible {
return "Failed to get classifier version: \(error)" return "Failed to get classifier version: \(error)"
case .invalidClassifierVersion(let string): case .invalidClassifierVersion(let string):
return "Invalid classifier version \(string)" return "Invalid classifier version \(string)"
case .failedToCreateClassifier(let result): case .failedToCreateClassifier(let error):
return "Failed to create classifier: \(result)" return "Failed to create classifier: \(error)"
case .failedToWriteClassifier(let error): case .failedToWriteClassifier(let error):
return "Failed to save model to file: \(error)" return "Failed to save model to file: \(error)"
case .failedToReadClassifierData(let error): case .failedToReadClassifierData(let error):