Move to async training

This commit is contained in:
Christoph Hagen 2023-10-23 14:59:25 +02:00
parent 0b1e439676
commit d654699b52
2 changed files with 76 additions and 10 deletions

View File

@ -1,5 +1,6 @@
import Foundation
import CreateML
import Combine
final class ClassifierCreator {
@ -11,6 +12,8 @@ final class ClassifierCreator {
let thumbnailDirectory: URL
let sessionDirectory: URL
let classifierUrl: URL
let df = DateFormatter()
@ -23,6 +26,7 @@ final class ClassifierCreator {
self.server = try configuration.serverUrl()
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
self.imageDirectory = contentDirectory.appendingPathComponent("images")
self.sessionDirectory = contentDirectory.appendingPathComponent("session")
self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
df.dateFormat = "yy-MM-dd-HH-mm-ss"
@ -55,8 +59,8 @@ final class ClassifierCreator {
print("[INFO] Classes: \(classes.count)")
print("[INFO] Iterations: \(configuration.trainingIterations)")
try trainModel()
try await upload(version: newVersion)
try await trainAndSaveModel()
try await uploadModel(version: newVersion)
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
await createThumbnails(changed: changedMainImages)
print("[INFO] Done")
@ -254,18 +258,80 @@ final class ClassifierCreator {
// MARK: Step 4: Train classifier
func trainModel() throws {
var params = MLImageClassifier.ModelParameters(augmentation: [])
params.maxIterations = configuration.trainingIterations
let model: MLImageClassifier
func trainAndSaveModel() async throws {
let model = try await trainModelAsync()
//let model = try trainModelSync()
try save(model: model)
}
func trainModelAsync() async throws -> MLImageClassifier {
let params = MLImageClassifier.ModelParameters(
maxIterations: configuration.trainingIterations,
augmentation: [])
let sessionParameters = MLTrainingSessionParameters(
sessionDirectory: sessionDirectory,
reportInterval: 10,
checkpointInterval: 100,
iterations: 1000)
var subscriptions = [AnyCancellable]()
let job: MLJob<MLImageClassifier>
do {
model = try MLImageClassifier(
job = try MLImageClassifier.train(
trainingData: .labeledDirectories(at: imageDirectory),
parameters: params,
sessionParameters: sessionParameters)
} catch {
throw TrainingError.failedToCreateClassifier
}
job.progress
.publisher(for: \.fractionCompleted)
.sink { completed in
print(String(format: " %.1f %% completed", completed * 100), terminator: "\r")
fflush(stdout)
//guard let progress = MLProgress(progress: job.progress) else {
// return
//}
//if let styleLoss = progress.metrics[.styleLoss] { _ = styleLoss }
//if let contentLoss = progress.metrics[.contentLoss] { _ = contentLoss }
}
.store(in: &subscriptions)
return try await withCheckedThrowingContinuation { continuation in
// Register a sink to receive the resulting model.
job.result.sink { result in
print("[ERROR] \(result)")
continuation.resume(throwing: TrainingError.failedToCreateClassifier)
} receiveValue: { model in
// Use model
print("[INFO] Created model")
continuation.resume(returning: model)
}
.store(in: &subscriptions)
}
}
func trainModelSync() throws -> MLImageClassifier {
let params = MLImageClassifier.ModelParameters(
maxIterations: configuration.trainingIterations,
augmentation: [])
do {
return try MLImageClassifier(
trainingData: .labeledDirectories(at: imageDirectory),
parameters: params)
} catch {
print("[ERROR] Failed to create classifier: \(error)")
throw TrainingError.failedToCreateClassifier(error)
throw TrainingError.failedToCreateClassifier
}
}
private func save(model: MLImageClassifier) throws {
print("[INFO] Saving classifier...")
do {
@ -278,7 +344,7 @@ final class ClassifierCreator {
// MARK: Step 5: Upload classifier
func upload(version: Int) async throws {
func uploadModel(version: Int) async throws {
print("[INFO] Uploading classifier...")
let modelData: Data
do {

View File

@ -17,7 +17,7 @@ enum TrainingError: Error {
case failedToGetClassifierVersion
case invalidClassifierVersion(String)
case failedToCreateClassifier(Error)
case failedToCreateClassifier
case failedToWriteClassifier(Error)
case failedToGetChangedImagesList