67 lines
2.8 KiB
Swift
67 lines
2.8 KiB
Swift
import ArgumentParser
|
|
import Foundation
|
|
|
|
/// Command-line tool that trains the cap classifier and (per the option help texts) uploads it
/// to the caps server, delegating the actual work to `ClassifierCreator`.
///
/// Settings may come from the command line or from a JSON configuration file passed with
/// `--configuration`. For every setting the command line wins over the file. The number of
/// iterations and the minimum image count fall back to built-in defaults; folder, server and
/// authentication token are required from one of the two sources.
@main
struct CapTrain: AsyncParsableCommand {

    /// Training iterations used when neither the command line nor the configuration file sets one.
    private static let defaultIterations = 10

    /// Minimum images per cap used when neither the command line nor the configuration file sets one.
    private static let defaultMinimumImageCount = 10

    // MARK: - Arguments
    // Declaration order is preserved on purpose: ArgumentParser lists options in this order in `--help`.

    @Flag(name: .shortAndLong, help: "Resume the previous training session (default: false)")
    var resume = false

    @Option(name: .shortAndLong, help: "The path to the configuration file. The file must be a json object containing command line arguments. Command line options take precedence over configuration file options")
    var configuration: String?

    @Option(name: .shortAndLong, help: "The number of iterations to train (default: 10)")
    var iterations: Int?

    @Option(name: .shortAndLong, help: "The minimum number of images for a cap to be included in training (default: 10)")
    var minimumImageCount: Int?

    @Option(name: .shortAndLong, help: "The url of the caps server to retrieve images and upload the classifier")
    var server: String?

    @Option(name: .shortAndLong, help: "The authentication token for the server")
    var authentication: String?

    @Option(name: .shortAndLong, help: "The path to the folder where the content (images, classifier, thumbnails) is stored")
    var folder: String?

    @Flag(name: .long, help: "Only upload the previous classifier")
    var skipTraining = false

    // MARK: - Execution

    /// Resolves every setting, builds a `Configuration`, and runs a `ClassifierCreator` with it.
    ///
    /// - Throws: `TrainingError.missingArguments` naming the first of folder, server or
    ///   authentication that neither source provides. Errors from loading the configuration file,
    ///   from creating the `ClassifierCreator`, or from running it are propagated.
    func run() async throws {
        let fileOptions = try loadConfigurationFile()

        // Precedence: command-line value, then configuration-file value, then built-in default.
        let trainingIterations = iterations ?? fileOptions?.iterations ?? Self.defaultIterations
        let minimumImagesPerCap = minimumImageCount ?? fileOptions?.minimumImagesPerCap ?? Self.defaultMinimumImageCount

        // Checked one after another so the first missing setting is the one reported.
        let contentFolder = try Self.requireArgument(folder ?? fileOptions?.folder, named: "folder")
        let serverPath = try Self.requireArgument(server ?? fileOptions?.server, named: "server")
        let authenticationToken = try Self.requireArgument(authentication ?? fileOptions?.authentication, named: "authentication")

        let trainingConfiguration = Configuration(
            contentFolder: contentFolder,
            trainingIterations: trainingIterations,
            serverPath: serverPath,
            authenticationToken: authenticationToken,
            minimumImagesPerCap: minimumImagesPerCap)

        let creator = try ClassifierCreator(configuration: trainingConfiguration, resume: resume)
        try await creator.run(skipTraining: skipTraining)
    }

    // MARK: - Helpers

    /// Returns `value`, or throws `TrainingError.missingArguments(name)` when it is `nil`.
    private static func requireArgument(_ value: String?, named name: String) throws -> String {
        guard let value else {
            throw TrainingError.missingArguments(name)
        }
        return value
    }

    /// Loads the file named by `--configuration`, or returns `nil` when that option was not given.
    ///
    /// The path is turned into a file URL with `URL(fileURLWithPath:)`, so a relative path is
    /// resolved against the current working directory. Errors from `ConfigurationFile(at:)` propagate.
    private func loadConfigurationFile() throws -> ConfigurationFile? {
        guard let path = configuration else {
            return nil
        }
        return try ConfigurationFile(at: URL(fileURLWithPath: path))
    }
}
|