Add resume feature
This commit is contained in:
parent
3451bb7254
commit
6aa1026478
@ -4,6 +4,9 @@ import Foundation
|
|||||||
@main
|
@main
|
||||||
struct CapTrain: AsyncParsableCommand {
|
struct CapTrain: AsyncParsableCommand {
|
||||||
|
|
||||||
|
@Flag(name: .shortAndLong, help: "Resume the previous training session (default: false)")
|
||||||
|
var resume: Bool = false
|
||||||
|
|
||||||
@Argument(help: "The path to the configuration file")
|
@Argument(help: "The path to the configuration file")
|
||||||
var configPath: String?
|
var configPath: String?
|
||||||
|
|
||||||
@ -35,7 +38,7 @@ struct CapTrain: AsyncParsableCommand {
|
|||||||
serverPath: serverPath,
|
serverPath: serverPath,
|
||||||
authenticationToken: authenticationToken)
|
authenticationToken: authenticationToken)
|
||||||
|
|
||||||
let creator = try ClassifierCreator(configuration: configuration)
|
let creator = try ClassifierCreator(configuration: configuration, resume: resume)
|
||||||
try await creator.run()
|
try await creator.run()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ final class ClassifierCreator {
|
|||||||
|
|
||||||
// MARK: Step 1: Load configuration
|
// MARK: Step 1: Load configuration
|
||||||
|
|
||||||
init(configuration: Configuration) throws {
|
init(configuration: Configuration, resume: Bool) throws {
|
||||||
self.configuration = configuration
|
self.configuration = configuration
|
||||||
self.server = try configuration.serverUrl()
|
self.server = try configuration.serverUrl()
|
||||||
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
|
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
|
||||||
@ -38,6 +38,21 @@ final class ClassifierCreator {
|
|||||||
self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
|
self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
|
||||||
self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
|
self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
|
||||||
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||||
|
if !resume {
|
||||||
|
try removeSessionData()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func removeSessionData() throws {
|
||||||
|
guard FileManager.default.fileExists(atPath: sessionDirectory.path) else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
try FileManager.default.removeItem(at: sessionDirectory)
|
||||||
|
print(info: "Removed training session data")
|
||||||
|
} catch {
|
||||||
|
throw TrainingError.failedToRemoveSessionFolder(error)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: Main function
|
// MARK: Main function
|
||||||
@ -71,6 +86,7 @@ final class ClassifierCreator {
|
|||||||
try await uploadModel(version: newVersion)
|
try await uploadModel(version: newVersion)
|
||||||
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
|
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
|
||||||
try await createThumbnails(changed: changedMainImages)
|
try await createThumbnails(changed: changedMainImages)
|
||||||
|
try removeSessionData()
|
||||||
print(info: "Done")
|
print(info: "Done")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ enum TrainingError: Error {
|
|||||||
case configurationFileDecodingFailed(URL, Error)
|
case configurationFileDecodingFailed(URL, Error)
|
||||||
case invalidServerPath(String)
|
case invalidServerPath(String)
|
||||||
case mainImageFolderNotCreated(Error)
|
case mainImageFolderNotCreated(Error)
|
||||||
|
case failedToRemoveSessionFolder(Error)
|
||||||
|
|
||||||
case failedToGetCapDatabase(Error)
|
case failedToGetCapDatabase(Error)
|
||||||
case failedToDecodeCapDatabase(Error)
|
case failedToDecodeCapDatabase(Error)
|
||||||
@ -65,6 +66,8 @@ extension TrainingError: CustomStringConvertible {
|
|||||||
return "Configuration: Invalid server path \(path)"
|
return "Configuration: Invalid server path \(path)"
|
||||||
case .mainImageFolderNotCreated(let error):
|
case .mainImageFolderNotCreated(let error):
|
||||||
return "Failed to create main image folder: \(error)"
|
return "Failed to create main image folder: \(error)"
|
||||||
|
case .failedToRemoveSessionFolder(let error):
|
||||||
|
return "Failed to remove session folder: \(error)"
|
||||||
|
|
||||||
case .failedToGetCapDatabase(let error):
|
case .failedToGetCapDatabase(let error):
|
||||||
return "Failed to get cap database from server: \(error)"
|
return "Failed to get cap database from server: \(error)"
|
||||||
|
Loading…
Reference in New Issue
Block a user