Move to async training
This commit is contained in:
parent
0b1e439676
commit
d654699b52
@ -1,5 +1,6 @@
|
|||||||
import Foundation
|
import Foundation
|
||||||
import CreateML
|
import CreateML
|
||||||
|
import Combine
|
||||||
|
|
||||||
final class ClassifierCreator {
|
final class ClassifierCreator {
|
||||||
|
|
||||||
@ -11,6 +12,8 @@ final class ClassifierCreator {
|
|||||||
|
|
||||||
let thumbnailDirectory: URL
|
let thumbnailDirectory: URL
|
||||||
|
|
||||||
|
let sessionDirectory: URL
|
||||||
|
|
||||||
let classifierUrl: URL
|
let classifierUrl: URL
|
||||||
|
|
||||||
let df = DateFormatter()
|
let df = DateFormatter()
|
||||||
@ -23,6 +26,7 @@ final class ClassifierCreator {
|
|||||||
self.server = try configuration.serverUrl()
|
self.server = try configuration.serverUrl()
|
||||||
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
|
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
|
||||||
self.imageDirectory = contentDirectory.appendingPathComponent("images")
|
self.imageDirectory = contentDirectory.appendingPathComponent("images")
|
||||||
|
self.sessionDirectory = contentDirectory.appendingPathComponent("session")
|
||||||
self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
|
self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
|
||||||
self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
|
self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
|
||||||
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||||
@ -55,8 +59,8 @@ final class ClassifierCreator {
|
|||||||
print("[INFO] Classes: \(classes.count)")
|
print("[INFO] Classes: \(classes.count)")
|
||||||
print("[INFO] Iterations: \(configuration.trainingIterations)")
|
print("[INFO] Iterations: \(configuration.trainingIterations)")
|
||||||
|
|
||||||
try trainModel()
|
try await trainAndSaveModel()
|
||||||
try await upload(version: newVersion)
|
try await uploadModel(version: newVersion)
|
||||||
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
|
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
|
||||||
await createThumbnails(changed: changedMainImages)
|
await createThumbnails(changed: changedMainImages)
|
||||||
print("[INFO] Done")
|
print("[INFO] Done")
|
||||||
@ -254,18 +258,80 @@ final class ClassifierCreator {
|
|||||||
|
|
||||||
// MARK: Step 4: Train classifier
|
// MARK: Step 4: Train classifier
|
||||||
|
|
||||||
func trainModel() throws {
|
func trainAndSaveModel() async throws {
|
||||||
var params = MLImageClassifier.ModelParameters(augmentation: [])
|
let model = try await trainModelAsync()
|
||||||
params.maxIterations = configuration.trainingIterations
|
//let model = try trainModelSync()
|
||||||
let model: MLImageClassifier
|
try save(model: model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func trainModelAsync() async throws -> MLImageClassifier {
|
||||||
|
let params = MLImageClassifier.ModelParameters(
|
||||||
|
maxIterations: configuration.trainingIterations,
|
||||||
|
augmentation: [])
|
||||||
|
|
||||||
|
let sessionParameters = MLTrainingSessionParameters(
|
||||||
|
sessionDirectory: sessionDirectory,
|
||||||
|
reportInterval: 10,
|
||||||
|
checkpointInterval: 100,
|
||||||
|
iterations: 1000)
|
||||||
|
|
||||||
|
var subscriptions = [AnyCancellable]()
|
||||||
|
|
||||||
|
let job: MLJob<MLImageClassifier>
|
||||||
do {
|
do {
|
||||||
model = try MLImageClassifier(
|
job = try MLImageClassifier.train(
|
||||||
|
trainingData: .labeledDirectories(at: imageDirectory),
|
||||||
|
parameters: params,
|
||||||
|
sessionParameters: sessionParameters)
|
||||||
|
} catch {
|
||||||
|
throw TrainingError.failedToCreateClassifier
|
||||||
|
}
|
||||||
|
|
||||||
|
job.progress
|
||||||
|
.publisher(for: \.fractionCompleted)
|
||||||
|
.sink { completed in
|
||||||
|
print(String(format: " %.1f %% completed", completed * 100), terminator: "\r")
|
||||||
|
fflush(stdout)
|
||||||
|
|
||||||
|
//guard let progress = MLProgress(progress: job.progress) else {
|
||||||
|
// return
|
||||||
|
//}
|
||||||
|
//if let styleLoss = progress.metrics[.styleLoss] { _ = styleLoss }
|
||||||
|
//if let contentLoss = progress.metrics[.contentLoss] { _ = contentLoss }
|
||||||
|
|
||||||
|
}
|
||||||
|
.store(in: &subscriptions)
|
||||||
|
|
||||||
|
return try await withCheckedThrowingContinuation { continuation in
|
||||||
|
// Register a sink to receive the resulting model.
|
||||||
|
job.result.sink { result in
|
||||||
|
print("[ERROR] \(result)")
|
||||||
|
continuation.resume(throwing: TrainingError.failedToCreateClassifier)
|
||||||
|
} receiveValue: { model in
|
||||||
|
// Use model
|
||||||
|
print("[INFO] Created model")
|
||||||
|
continuation.resume(returning: model)
|
||||||
|
}
|
||||||
|
.store(in: &subscriptions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func trainModelSync() throws -> MLImageClassifier {
|
||||||
|
let params = MLImageClassifier.ModelParameters(
|
||||||
|
maxIterations: configuration.trainingIterations,
|
||||||
|
augmentation: [])
|
||||||
|
|
||||||
|
do {
|
||||||
|
return try MLImageClassifier(
|
||||||
trainingData: .labeledDirectories(at: imageDirectory),
|
trainingData: .labeledDirectories(at: imageDirectory),
|
||||||
parameters: params)
|
parameters: params)
|
||||||
} catch {
|
} catch {
|
||||||
print("[ERROR] Failed to create classifier: \(error)")
|
print("[ERROR] Failed to create classifier: \(error)")
|
||||||
throw TrainingError.failedToCreateClassifier(error)
|
throw TrainingError.failedToCreateClassifier
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func save(model: MLImageClassifier) throws {
|
||||||
|
|
||||||
print("[INFO] Saving classifier...")
|
print("[INFO] Saving classifier...")
|
||||||
do {
|
do {
|
||||||
@ -278,7 +344,7 @@ final class ClassifierCreator {
|
|||||||
|
|
||||||
// MARK: Step 5: Upload classifier
|
// MARK: Step 5: Upload classifier
|
||||||
|
|
||||||
func upload(version: Int) async throws {
|
func uploadModel(version: Int) async throws {
|
||||||
print("[INFO] Uploading classifier...")
|
print("[INFO] Uploading classifier...")
|
||||||
let modelData: Data
|
let modelData: Data
|
||||||
do {
|
do {
|
||||||
|
@ -17,7 +17,7 @@ enum TrainingError: Error {
|
|||||||
case failedToGetClassifierVersion
|
case failedToGetClassifierVersion
|
||||||
case invalidClassifierVersion(String)
|
case invalidClassifierVersion(String)
|
||||||
|
|
||||||
case failedToCreateClassifier(Error)
|
case failedToCreateClassifier
|
||||||
case failedToWriteClassifier(Error)
|
case failedToWriteClassifier(Error)
|
||||||
|
|
||||||
case failedToGetChangedImagesList
|
case failedToGetChangedImagesList
|
||||||
|
Loading…
Reference in New Issue
Block a user