Allow skipping of training
This commit is contained in:
parent
796c18c7e0
commit
58b1dcb277
@ -24,6 +24,9 @@ struct CapTrain: AsyncParsableCommand {
|
||||
@Option(name: .shortAndLong, help: "The path to the folder where the content (images, classifier, thumbnails) is stored")
|
||||
var folder: String?
|
||||
|
||||
@Flag(name: .long, help: "Only upload the previous classifier")
|
||||
var skipTraining: Bool = false
|
||||
|
||||
func run() async throws {
|
||||
let configurationFile = try configurationFile()
|
||||
let iterations = iterations ?? configurationFile?.iterations ?? CapTrain.defaultIterations
|
||||
@ -36,7 +39,6 @@ struct CapTrain: AsyncParsableCommand {
|
||||
guard let authentication = authentication ?? configurationFile?.authentication else {
|
||||
throw TrainingError.missingArguments("authentication")
|
||||
}
|
||||
|
||||
let configuration = Configuration(
|
||||
contentFolder: contentFolder,
|
||||
trainingIterations: iterations,
|
||||
@ -44,7 +46,7 @@ struct CapTrain: AsyncParsableCommand {
|
||||
authenticationToken: authentication)
|
||||
|
||||
let creator = try ClassifierCreator(configuration: configuration, resume: resume)
|
||||
try await creator.run()
|
||||
try await creator.run(skipTraining: skipTraining)
|
||||
}
|
||||
|
||||
private func configurationFile() throws -> ConfigurationFile? {
|
||||
|
@ -59,7 +59,7 @@ final class ClassifierCreator {
|
||||
|
||||
// MARK: Main function
|
||||
|
||||
func run() async throws {
|
||||
func run(skipTraining: Bool) async throws {
|
||||
let imagesSnapshotDate = Date()
|
||||
let (classes, changedImageCount, changedMainImages) = try await loadImages()
|
||||
|
||||
@ -78,13 +78,16 @@ final class ClassifierCreator {
|
||||
let classifierVersion = try await getClassifierVersion()
|
||||
let newVersion = classifierVersion + 1
|
||||
|
||||
print(info: "Server: \(server.path)")
|
||||
print(info: "Image directory: \(imageDirectory.absoluteURL.path)")
|
||||
print(info: "Model path: \(classifierUrl.path)")
|
||||
print(info: "Version: \(newVersion)")
|
||||
print(info: "Classes: \(classes.count)")
|
||||
print(info: "Iterations: \(configuration.trainingIterations)")
|
||||
|
||||
if !skipTraining {
|
||||
try await trainAndSaveModel()
|
||||
}
|
||||
try await uploadModel(version: newVersion)
|
||||
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
|
||||
try await createThumbnails(changed: changedMainImages)
|
||||
|
Loading…
Reference in New Issue
Block a user