Add resume feature
This commit is contained in:
parent
3451bb7254
commit
6aa1026478
@ -4,6 +4,9 @@ import Foundation
|
||||
@main
|
||||
struct CapTrain: AsyncParsableCommand {
|
||||
|
||||
@Flag(name: .shortAndLong, help: "Resume the previous training session (default: false)")
|
||||
var resume: Bool = false
|
||||
|
||||
@Argument(help: "The path to the configuration file")
|
||||
var configPath: String?
|
||||
|
||||
@ -35,7 +38,7 @@ struct CapTrain: AsyncParsableCommand {
|
||||
serverPath: serverPath,
|
||||
authenticationToken: authenticationToken)
|
||||
|
||||
let creator = try ClassifierCreator(configuration: configuration)
|
||||
let creator = try ClassifierCreator(configuration: configuration, resume: resume)
|
||||
try await creator.run()
|
||||
}
|
||||
|
||||
|
@ -29,7 +29,7 @@ final class ClassifierCreator {
|
||||
|
||||
// MARK: Step 1: Load configuration
|
||||
|
||||
init(configuration: Configuration) throws {
|
||||
init(configuration: Configuration, resume: Bool) throws {
|
||||
self.configuration = configuration
|
||||
self.server = try configuration.serverUrl()
|
||||
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
|
||||
@ -38,6 +38,21 @@ final class ClassifierCreator {
|
||||
self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
|
||||
self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
|
||||
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||
if !resume {
|
||||
try removeSessionData()
|
||||
}
|
||||
}
|
||||
|
||||
private func removeSessionData() throws {
|
||||
guard FileManager.default.fileExists(atPath: sessionDirectory.path) else {
|
||||
return
|
||||
}
|
||||
do {
|
||||
try FileManager.default.removeItem(at: sessionDirectory)
|
||||
print(info: "Removed training session data")
|
||||
} catch {
|
||||
throw TrainingError.failedToRemoveSessionFolder(error)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: Main function
|
||||
@ -71,6 +86,7 @@ final class ClassifierCreator {
|
||||
try await uploadModel(version: newVersion)
|
||||
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
|
||||
try await createThumbnails(changed: changedMainImages)
|
||||
try removeSessionData()
|
||||
print(info: "Done")
|
||||
}
|
||||
|
||||
|
@ -11,6 +11,7 @@ enum TrainingError: Error {
|
||||
case configurationFileDecodingFailed(URL, Error)
|
||||
case invalidServerPath(String)
|
||||
case mainImageFolderNotCreated(Error)
|
||||
case failedToRemoveSessionFolder(Error)
|
||||
|
||||
case failedToGetCapDatabase(Error)
|
||||
case failedToDecodeCapDatabase(Error)
|
||||
@ -65,6 +66,8 @@ extension TrainingError: CustomStringConvertible {
|
||||
return "Configuration: Invalid server path \(path)"
|
||||
case .mainImageFolderNotCreated(let error):
|
||||
return "Failed to create main image folder: \(error)"
|
||||
case .failedToRemoveSessionFolder(let error):
|
||||
return "Failed to remove session folder: \(error)"
|
||||
|
||||
case .failedToGetCapDatabase(let error):
|
||||
return "Failed to get cap database from server: \(error)"
|
||||
|
Loading…
Reference in New Issue
Block a user