Caps-Train/Sources/CapTrain.swift
2024-10-28 19:06:45 +01:00

67 lines
2.8 KiB
Swift

import ArgumentParser
import Foundation
/// Command-line entry point that trains the caps classifier and uploads it to the caps server.
///
/// Every setting except the two flags is resolved with the same precedence:
/// an explicit command-line option wins, then the matching value from the JSON file passed via
/// `--configuration`, then (for the numeric settings only) a built-in default.
/// `--resume` and `--skip-training` are read from the command line only.
@main
struct CapTrain: AsyncParsableCommand {

    /// Fallback for `--iterations` when neither the command line nor the configuration file sets it.
    private static let defaultIterations = 10

    /// Fallback for `--minimum-image-count` when neither the command line nor the configuration file sets it.
    private static let defaultMinimumImageCount = 10

    @Flag(name: .shortAndLong, help: "Resume the previous training session (default: false)")
    var resume: Bool = false

    @Option(name: .shortAndLong, help: "The path to the configuration file. The file must be a json object containing command line arguments. Command line options take precedence over configuration file options")
    var configuration: String?

    // The numeric options are optional (rather than defaulted) so that an unset option can fall back
    // to the configuration file; the help text therefore shows the default explicitly, taken from the
    // constants above so the two cannot drift apart.
    @Option(name: .shortAndLong, help: "The number of iterations to train (default: \(CapTrain.defaultIterations))")
    var iterations: Int?

    @Option(name: .shortAndLong, help: "The minimum number of images for a cap to be included in training (default: \(CapTrain.defaultMinimumImageCount))")
    var minimumImageCount: Int?

    @Option(name: .shortAndLong, help: "The url of the caps server to retrieve images and upload the classifier")
    var server: String?

    @Option(name: .shortAndLong, help: "The authentication token for the server")
    var authentication: String?

    @Option(name: .shortAndLong, help: "The path to the folder where the content (images, classifier, thumbnails) is stored")
    var folder: String?

    @Flag(name: .long, help: "Only upload the previous classifier")
    var skipTraining: Bool = false

    /// Resolves the training settings, then hands them to `ClassifierCreator`.
    ///
    /// - Throws: `TrainingError.missingArguments` naming the first of `folder`, `server` or
    ///   `authentication` that is set neither on the command line nor in the configuration file;
    ///   otherwise any error thrown while loading the configuration file, creating the
    ///   `ClassifierCreator`, or running it.
    func run() async throws {
        let trainingConfiguration = try resolvedConfiguration()
        let creator = try ClassifierCreator(configuration: trainingConfiguration, resume: resume)
        try await creator.run(skipTraining: skipTraining)
    }

    /// Merges command-line options, configuration-file values and built-in defaults into a `Configuration`.
    ///
    /// The configuration file is loaded first, so an error reading it is reported before any
    /// missing-argument error.
    private func resolvedConfiguration() throws -> Configuration {
        let fileOptions = try configurationFile()

        let resolvedIterations = iterations
            ?? fileOptions?.iterations
            ?? CapTrain.defaultIterations
        let resolvedMinimumCount = minimumImageCount
            ?? fileOptions?.minimumImagesPerCap
            ?? CapTrain.defaultMinimumImageCount

        // These settings have no sensible default and must come from the command line or the file.
        guard let contentFolder = folder ?? fileOptions?.folder else {
            throw TrainingError.missingArguments("folder")
        }
        guard let serverPath = server ?? fileOptions?.server else {
            throw TrainingError.missingArguments("server")
        }
        guard let token = authentication ?? fileOptions?.authentication else {
            throw TrainingError.missingArguments("authentication")
        }

        return Configuration(
            contentFolder: contentFolder,
            trainingIterations: resolvedIterations,
            serverPath: serverPath,
            authenticationToken: token,
            minimumImagesPerCap: resolvedMinimumCount)
    }

    /// Loads the configuration file named by `--configuration`.
    ///
    /// - Returns: `nil` when `--configuration` was not given; otherwise the loaded file.
    /// - Throws: any error thrown by `ConfigurationFile(at:)`.
    private func configurationFile() throws -> ConfigurationFile? {
        guard let configuration else {
            return nil
        }
        return try ConfigurationFile(at: URL(fileURLWithPath: configuration))
    }
}