Add resume feature

This commit is contained in:
Christoph Hagen 2023-10-24 11:16:22 +02:00
parent 3451bb7254
commit 6aa1026478
3 changed files with 24 additions and 2 deletions

View File

@ -4,6 +4,9 @@ import Foundation
@main @main
struct CapTrain: AsyncParsableCommand { struct CapTrain: AsyncParsableCommand {
@Flag(name: .shortAndLong, help: "Resume the previous training session (default: false)")
var resume: Bool = false
@Argument(help: "The path to the configuration file") @Argument(help: "The path to the configuration file")
var configPath: String? var configPath: String?
@ -35,7 +38,7 @@ struct CapTrain: AsyncParsableCommand {
serverPath: serverPath, serverPath: serverPath,
authenticationToken: authenticationToken) authenticationToken: authenticationToken)
let creator = try ClassifierCreator(configuration: configuration) let creator = try ClassifierCreator(configuration: configuration, resume: resume)
try await creator.run() try await creator.run()
} }

View File

@ -29,7 +29,7 @@ final class ClassifierCreator {
// MARK: Step 1: Load configuration // MARK: Step 1: Load configuration
init(configuration: Configuration) throws { init(configuration: Configuration, resume: Bool) throws {
self.configuration = configuration self.configuration = configuration
self.server = try configuration.serverUrl() self.server = try configuration.serverUrl()
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder) let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
@ -38,6 +38,21 @@ final class ClassifierCreator {
self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel") self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails") self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
df.dateFormat = "yy-MM-dd-HH-mm-ss" df.dateFormat = "yy-MM-dd-HH-mm-ss"
if !resume {
try removeSessionData()
}
}
private func removeSessionData() throws {
guard FileManager.default.fileExists(atPath: sessionDirectory.path) else {
return
}
do {
try FileManager.default.removeItem(at: sessionDirectory)
print(info: "Removed training session data")
} catch {
throw TrainingError.failedToRemoveSessionFolder(error)
}
} }
// MARK: Main function // MARK: Main function
@ -71,6 +86,7 @@ final class ClassifierCreator {
try await uploadModel(version: newVersion) try await uploadModel(version: newVersion)
try await upload(classes: classes, lastUpdate: imagesSnapshotDate) try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
try await createThumbnails(changed: changedMainImages) try await createThumbnails(changed: changedMainImages)
try removeSessionData()
print(info: "Done") print(info: "Done")
} }

View File

@ -11,6 +11,7 @@ enum TrainingError: Error {
case configurationFileDecodingFailed(URL, Error) case configurationFileDecodingFailed(URL, Error)
case invalidServerPath(String) case invalidServerPath(String)
case mainImageFolderNotCreated(Error) case mainImageFolderNotCreated(Error)
case failedToRemoveSessionFolder(Error)
case failedToGetCapDatabase(Error) case failedToGetCapDatabase(Error)
case failedToDecodeCapDatabase(Error) case failedToDecodeCapDatabase(Error)
@ -65,6 +66,8 @@ extension TrainingError: CustomStringConvertible {
return "Configuration: Invalid server path \(path)" return "Configuration: Invalid server path \(path)"
case .mainImageFolderNotCreated(let error): case .mainImageFolderNotCreated(let error):
return "Failed to create main image folder: \(error)" return "Failed to create main image folder: \(error)"
case .failedToRemoveSessionFolder(let error):
return "Failed to remove session folder: \(error)"
case .failedToGetCapDatabase(let error): case .failedToGetCapDatabase(let error):
return "Failed to get cap database from server: \(error)" return "Failed to get cap database from server: \(error)"