From 58b1dcb277228d5f0ac59e52b13cd5a3c099dec1 Mon Sep 17 00:00:00 2001 From: Christoph Hagen Date: Wed, 12 Jun 2024 16:13:42 +0200 Subject: [PATCH] Allow skipping of training --- Sources/CapTrain.swift | 6 ++++-- Sources/ClassifierCreator.swift | 7 +++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/Sources/CapTrain.swift b/Sources/CapTrain.swift index 97340b1..cb00b59 100644 --- a/Sources/CapTrain.swift +++ b/Sources/CapTrain.swift @@ -23,6 +23,9 @@ struct CapTrain: AsyncParsableCommand { @Option(name: .shortAndLong, help: "The path to the folder where the content (images, classifier, thumbnails) is stored") var folder: String? + + @Flag(name: .long, help: "Only upload the previous classifier") + var skipTraining: Bool = false func run() async throws { let configurationFile = try configurationFile() @@ -36,7 +39,6 @@ struct CapTrain: AsyncParsableCommand { guard let authentication = authentication ?? configurationFile?.authentication else { throw TrainingError.missingArguments("authentication") } - let configuration = Configuration( contentFolder: contentFolder, trainingIterations: iterations, @@ -44,7 +46,7 @@ struct CapTrain: AsyncParsableCommand { authenticationToken: authentication) let creator = try ClassifierCreator(configuration: configuration, resume: resume) - try await creator.run() + try await creator.run(skipTraining: skipTraining) } private func configurationFile() throws -> ConfigurationFile? { diff --git a/Sources/ClassifierCreator.swift b/Sources/ClassifierCreator.swift index ae70fd4..0060e76 100644 --- a/Sources/ClassifierCreator.swift +++ b/Sources/ClassifierCreator.swift @@ -59,7 +59,7 @@ final class ClassifierCreator { // MARK: Main function - func run() async throws { + func run(skipTraining: Bool) async throws { let imagesSnapshotDate = Date() let (classes, changedImageCount, changedMainImages) = try await loadImages() @@ -78,13 +78,16 @@ final class ClassifierCreator { let classifierVersion = try await getClassifierVersion() let newVersion = classifierVersion + 1 + print(info: "Server: \(server.path)") print(info: "Image directory: \(imageDirectory.absoluteURL.path)") print(info: "Model path: \(classifierUrl.path)") print(info: "Version: \(newVersion)") print(info: "Classes: \(classes.count)") print(info: "Iterations: \(configuration.trainingIterations)") - try await trainAndSaveModel() + if !skipTraining { + try await trainAndSaveModel() + } try await uploadModel(version: newVersion) try await upload(classes: classes, lastUpdate: imagesSnapshotDate) try await createThumbnails(changed: changedMainImages)