diff --git a/Sources/CapTrain.swift b/Sources/CapTrain.swift index a6bba5d..ff40327 100644 --- a/Sources/CapTrain.swift +++ b/Sources/CapTrain.swift @@ -5,14 +5,45 @@ import Foundation struct CapTrain: AsyncParsableCommand { @Argument(help: "The path to the configuration file") - var configPath: String + var configPath: String? + + @Option(name: .shortAndLong, help: "The number of iterations to train") + var iterations: Int? + + @Option(name: .shortAndLong, help: "The url of the caps server") + var serverPath: String? + + @Option(name: .shortAndLong, help: "The authentication token for the server") + var authentication: String? + + @Option(name: .shortAndLong, help: "The folder where the content (images, classifier, thumbnails) is stored") + var folder: String? func run() async throws { - let configurationFileUrl = URL(fileURLWithPath: configPath) - - let configuration = try Configuration(at: configurationFileUrl) - let creator = try ClassifierCreator(configuration: configuration) + let configurationFile = try configurationFile() + guard let contentFolder = folder ?? configurationFile?.contentFolder, + let trainingIterations = iterations ?? configurationFile?.trainingIterations, + let serverPath = serverPath ?? configurationFile?.serverPath, + let authenticationToken = authentication ?? configurationFile?.authenticationToken + else { + throw TrainingError.missingArguments + } + let configuration = Configuration( + contentFolder: contentFolder, + trainingIterations: trainingIterations, + serverPath: serverPath, + authenticationToken: authenticationToken) + + let creator = try ClassifierCreator(configuration: configuration) try await creator.run() } + + private func configurationFile() throws -> ConfigurationFile? { + guard let configPath else { + return nil + } + let configurationFileUrl = URL(fileURLWithPath: configPath) + return try ConfigurationFile(at: configurationFileUrl) + } } diff --git a/Sources/Configuration.swift b/Sources/Configuration.swift index 656221d..fffb3e0 100644 --- a/Sources/Configuration.swift +++ b/Sources/Configuration.swift @@ -1,34 +1,17 @@ import Foundation -struct Configuration: Codable { +struct Configuration { let contentFolder: String - + let trainingIterations: Int let serverPath: String let authenticationToken: String - - init(at url: URL) throws { - guard FileManager.default.fileExists(atPath: url.path) else { - print("[ERROR] No configuration at \(url.absoluteURL.path)") - throw TrainingError.configurationFileMissing - } - let data: Data - do { - data = try Data(contentsOf: url) - } catch { - print("[ERROR] Failed to load configuration data at \(url.absoluteURL.path): \(error)") - throw TrainingError.configurationFileUnreadable - } - do { - self = try JSONDecoder().decode(Configuration.self, from: data) - } catch { - print("[ERROR] Failed to decode configuration at \(url.absoluteURL.path): \(error)") - throw TrainingError.configurationFileDecodingFailed - } - } +} + +extension Configuration { func serverUrl() throws -> URL { guard let serverUrl = URL(string: serverPath) else { diff --git a/Sources/ConfigurationFile.swift b/Sources/ConfigurationFile.swift new file mode 100644 index 0000000..46f1c8b --- /dev/null +++ b/Sources/ConfigurationFile.swift @@ -0,0 +1,39 @@ +import Foundation + +struct ConfigurationFile { + + let contentFolder: String? + + let trainingIterations: Int? + + let serverPath: String? + + let authenticationToken: String? +} + +extension ConfigurationFile: Decodable { + +} + +extension ConfigurationFile { + + init(at url: URL) throws { + guard FileManager.default.fileExists(atPath: url.path) else { + print("[ERROR] No configuration at \(url.absoluteURL.path)") + throw TrainingError.configurationFileMissing + } + let data: Data + do { + data = try Data(contentsOf: url) + } catch { + print("[ERROR] Failed to load configuration data at \(url.absoluteURL.path): \(error)") + throw TrainingError.configurationFileUnreadable + } + do { + self = try JSONDecoder().decode(ConfigurationFile.self, from: data) + } catch { + print("[ERROR] Failed to decode configuration at \(url.absoluteURL.path): \(error)") + throw TrainingError.configurationFileDecodingFailed + } + } +} diff --git a/Sources/TrainingError.swift b/Sources/TrainingError.swift index b0648a4..e73b03b 100644 --- a/Sources/TrainingError.swift +++ b/Sources/TrainingError.swift @@ -2,6 +2,8 @@ import Foundation enum TrainingError: Error { + case missingArguments + case configurationFileMissing case configurationFileUnreadable case configurationFileDecodingFailed