Move to async training
This commit is contained in:
parent
0b1e439676
commit
d654699b52
@ -1,5 +1,6 @@
|
||||
import Foundation
|
||||
import CreateML
|
||||
import Combine
|
||||
|
||||
final class ClassifierCreator {
|
||||
|
||||
@ -11,6 +12,8 @@ final class ClassifierCreator {
|
||||
|
||||
let thumbnailDirectory: URL
|
||||
|
||||
let sessionDirectory: URL
|
||||
|
||||
let classifierUrl: URL
|
||||
|
||||
let df = DateFormatter()
|
||||
@ -23,6 +26,7 @@ final class ClassifierCreator {
|
||||
self.server = try configuration.serverUrl()
|
||||
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
|
||||
self.imageDirectory = contentDirectory.appendingPathComponent("images")
|
||||
self.sessionDirectory = contentDirectory.appendingPathComponent("session")
|
||||
self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
|
||||
self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
|
||||
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||
@ -55,8 +59,8 @@ final class ClassifierCreator {
|
||||
print("[INFO] Classes: \(classes.count)")
|
||||
print("[INFO] Iterations: \(configuration.trainingIterations)")
|
||||
|
||||
try trainModel()
|
||||
try await upload(version: newVersion)
|
||||
try await trainAndSaveModel()
|
||||
try await uploadModel(version: newVersion)
|
||||
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
|
||||
await createThumbnails(changed: changedMainImages)
|
||||
print("[INFO] Done")
|
||||
@ -254,18 +258,80 @@ final class ClassifierCreator {
|
||||
|
||||
// MARK: Step 4: Train classifier
|
||||
|
||||
func trainModel() throws {
|
||||
var params = MLImageClassifier.ModelParameters(augmentation: [])
|
||||
params.maxIterations = configuration.trainingIterations
|
||||
let model: MLImageClassifier
|
||||
func trainAndSaveModel() async throws {
|
||||
let model = try await trainModelAsync()
|
||||
//let model = try trainModelSync()
|
||||
try save(model: model)
|
||||
}
|
||||
|
||||
func trainModelAsync() async throws -> MLImageClassifier {
|
||||
let params = MLImageClassifier.ModelParameters(
|
||||
maxIterations: configuration.trainingIterations,
|
||||
augmentation: [])
|
||||
|
||||
let sessionParameters = MLTrainingSessionParameters(
|
||||
sessionDirectory: sessionDirectory,
|
||||
reportInterval: 10,
|
||||
checkpointInterval: 100,
|
||||
iterations: 1000)
|
||||
|
||||
var subscriptions = [AnyCancellable]()
|
||||
|
||||
let job: MLJob<MLImageClassifier>
|
||||
do {
|
||||
model = try MLImageClassifier(
|
||||
job = try MLImageClassifier.train(
|
||||
trainingData: .labeledDirectories(at: imageDirectory),
|
||||
parameters: params,
|
||||
sessionParameters: sessionParameters)
|
||||
} catch {
|
||||
throw TrainingError.failedToCreateClassifier
|
||||
}
|
||||
|
||||
job.progress
|
||||
.publisher(for: \.fractionCompleted)
|
||||
.sink { completed in
|
||||
print(String(format: " %.1f %% completed", completed * 100), terminator: "\r")
|
||||
fflush(stdout)
|
||||
|
||||
//guard let progress = MLProgress(progress: job.progress) else {
|
||||
// return
|
||||
//}
|
||||
//if let styleLoss = progress.metrics[.styleLoss] { _ = styleLoss }
|
||||
//if let contentLoss = progress.metrics[.contentLoss] { _ = contentLoss }
|
||||
|
||||
}
|
||||
.store(in: &subscriptions)
|
||||
|
||||
return try await withCheckedThrowingContinuation { continuation in
|
||||
// Register a sink to receive the resulting model.
|
||||
job.result.sink { result in
|
||||
print("[ERROR] \(result)")
|
||||
continuation.resume(throwing: TrainingError.failedToCreateClassifier)
|
||||
} receiveValue: { model in
|
||||
// Use model
|
||||
print("[INFO] Created model")
|
||||
continuation.resume(returning: model)
|
||||
}
|
||||
.store(in: &subscriptions)
|
||||
}
|
||||
}
|
||||
|
||||
func trainModelSync() throws -> MLImageClassifier {
|
||||
let params = MLImageClassifier.ModelParameters(
|
||||
maxIterations: configuration.trainingIterations,
|
||||
augmentation: [])
|
||||
|
||||
do {
|
||||
return try MLImageClassifier(
|
||||
trainingData: .labeledDirectories(at: imageDirectory),
|
||||
parameters: params)
|
||||
} catch {
|
||||
print("[ERROR] Failed to create classifier: \(error)")
|
||||
throw TrainingError.failedToCreateClassifier(error)
|
||||
throw TrainingError.failedToCreateClassifier
|
||||
}
|
||||
}
|
||||
|
||||
private func save(model: MLImageClassifier) throws {
|
||||
|
||||
print("[INFO] Saving classifier...")
|
||||
do {
|
||||
@ -278,7 +344,7 @@ final class ClassifierCreator {
|
||||
|
||||
// MARK: Step 5: Upload classifier
|
||||
|
||||
func upload(version: Int) async throws {
|
||||
func uploadModel(version: Int) async throws {
|
||||
print("[INFO] Uploading classifier...")
|
||||
let modelData: Data
|
||||
do {
|
||||
|
@ -17,7 +17,7 @@ enum TrainingError: Error {
|
||||
case failedToGetClassifierVersion
|
||||
case invalidClassifierVersion(String)
|
||||
|
||||
case failedToCreateClassifier(Error)
|
||||
case failedToCreateClassifier
|
||||
case failedToWriteClassifier(Error)
|
||||
|
||||
case failedToGetChangedImagesList
|
||||
|
Loading…
x
Reference in New Issue
Block a user