diff --git a/Sources/CapTrain.swift b/Sources/CapTrain.swift index ff40327..83ea092 100644 --- a/Sources/CapTrain.swift +++ b/Sources/CapTrain.swift @@ -4,6 +4,9 @@ import Foundation @main struct CapTrain: AsyncParsableCommand { + @Flag(name: .shortAndLong, help: "Resume the previous training session (default: false)") + var resume: Bool = false + @Argument(help: "The path to the configuration file") var configPath: String? @@ -35,7 +38,7 @@ struct CapTrain: AsyncParsableCommand { serverPath: serverPath, authenticationToken: authenticationToken) - let creator = try ClassifierCreator(configuration: configuration) + let creator = try ClassifierCreator(configuration: configuration, resume: resume) try await creator.run() } diff --git a/Sources/ClassifierCreator.swift b/Sources/ClassifierCreator.swift index e95f907..7be9e4f 100644 --- a/Sources/ClassifierCreator.swift +++ b/Sources/ClassifierCreator.swift @@ -29,7 +29,7 @@ final class ClassifierCreator { // MARK: Step 1: Load configuration - init(configuration: Configuration) throws { + init(configuration: Configuration, resume: Bool) throws { self.configuration = configuration self.server = try configuration.serverUrl() let contentDirectory = URL(fileURLWithPath: configuration.contentFolder) @@ -38,6 +38,21 @@ final class ClassifierCreator { self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel") self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails") df.dateFormat = "yy-MM-dd-HH-mm-ss" + if !resume { + try removeSessionData() + } + } + + private func removeSessionData() throws { + guard FileManager.default.fileExists(atPath: sessionDirectory.path) else { + return + } + do { + try FileManager.default.removeItem(at: sessionDirectory) + print(info: "Removed training session data") + } catch { + throw TrainingError.failedToRemoveSessionFolder(error) + } } // MARK: Main function @@ -71,6 +86,7 @@ final class ClassifierCreator { try await uploadModel(version: newVersion) try await upload(classes: classes, lastUpdate: imagesSnapshotDate) try await createThumbnails(changed: changedMainImages) + try removeSessionData() print(info: "Done") } diff --git a/Sources/TrainingError.swift b/Sources/TrainingError.swift index 0cf01fe..5efaea0 100644 --- a/Sources/TrainingError.swift +++ b/Sources/TrainingError.swift @@ -11,6 +11,7 @@ enum TrainingError: Error { case configurationFileDecodingFailed(URL, Error) case invalidServerPath(String) case mainImageFolderNotCreated(Error) + case failedToRemoveSessionFolder(Error) case failedToGetCapDatabase(Error) case failedToDecodeCapDatabase(Error) @@ -65,6 +66,8 @@ extension TrainingError: CustomStringConvertible { return "Configuration: Invalid server path \(path)" case .mainImageFolderNotCreated(let error): return "Failed to create main image folder: \(error)" + case .failedToRemoveSessionFolder(let error): + return "Failed to remove session folder: \(error)" case .failedToGetCapDatabase(let error): return "Failed to get cap database from server: \(error)"