Compare commits
2 Commits
796c18c7e0
...
93af0cb76d
Author | SHA1 | Date | |
---|---|---|---|
|
93af0cb76d | ||
|
58b1dcb277 |
@ -24,6 +24,9 @@ struct CapTrain: AsyncParsableCommand {
|
|||||||
@Option(name: .shortAndLong, help: "The path to the folder where the content (images, classifier, thumbnails) is stored")
|
@Option(name: .shortAndLong, help: "The path to the folder where the content (images, classifier, thumbnails) is stored")
|
||||||
var folder: String?
|
var folder: String?
|
||||||
|
|
||||||
|
@Flag(name: .long, help: "Only upload the previous classifier")
|
||||||
|
var skipTraining: Bool = false
|
||||||
|
|
||||||
func run() async throws {
|
func run() async throws {
|
||||||
let configurationFile = try configurationFile()
|
let configurationFile = try configurationFile()
|
||||||
let iterations = iterations ?? configurationFile?.iterations ?? CapTrain.defaultIterations
|
let iterations = iterations ?? configurationFile?.iterations ?? CapTrain.defaultIterations
|
||||||
@ -36,7 +39,6 @@ struct CapTrain: AsyncParsableCommand {
|
|||||||
guard let authentication = authentication ?? configurationFile?.authentication else {
|
guard let authentication = authentication ?? configurationFile?.authentication else {
|
||||||
throw TrainingError.missingArguments("authentication")
|
throw TrainingError.missingArguments("authentication")
|
||||||
}
|
}
|
||||||
|
|
||||||
let configuration = Configuration(
|
let configuration = Configuration(
|
||||||
contentFolder: contentFolder,
|
contentFolder: contentFolder,
|
||||||
trainingIterations: iterations,
|
trainingIterations: iterations,
|
||||||
@ -44,7 +46,7 @@ struct CapTrain: AsyncParsableCommand {
|
|||||||
authenticationToken: authentication)
|
authenticationToken: authentication)
|
||||||
|
|
||||||
let creator = try ClassifierCreator(configuration: configuration, resume: resume)
|
let creator = try ClassifierCreator(configuration: configuration, resume: resume)
|
||||||
try await creator.run()
|
try await creator.run(skipTraining: skipTraining)
|
||||||
}
|
}
|
||||||
|
|
||||||
private func configurationFile() throws -> ConfigurationFile? {
|
private func configurationFile() throws -> ConfigurationFile? {
|
||||||
|
@ -23,6 +23,8 @@ final class ClassifierCreator {
|
|||||||
|
|
||||||
let classifierUrl: URL
|
let classifierUrl: URL
|
||||||
|
|
||||||
|
let urlSession = URLSession(configuration: .ephemeral)
|
||||||
|
|
||||||
let df = DateFormatter()
|
let df = DateFormatter()
|
||||||
|
|
||||||
private func print(info: String) {
|
private func print(info: String) {
|
||||||
@ -59,7 +61,7 @@ final class ClassifierCreator {
|
|||||||
|
|
||||||
// MARK: Main function
|
// MARK: Main function
|
||||||
|
|
||||||
func run() async throws {
|
func run(skipTraining: Bool) async throws {
|
||||||
let imagesSnapshotDate = Date()
|
let imagesSnapshotDate = Date()
|
||||||
let (classes, changedImageCount, changedMainImages) = try await loadImages()
|
let (classes, changedImageCount, changedMainImages) = try await loadImages()
|
||||||
|
|
||||||
@ -78,13 +80,16 @@ final class ClassifierCreator {
|
|||||||
let classifierVersion = try await getClassifierVersion()
|
let classifierVersion = try await getClassifierVersion()
|
||||||
let newVersion = classifierVersion + 1
|
let newVersion = classifierVersion + 1
|
||||||
|
|
||||||
|
print(info: "Server: \(server.path)")
|
||||||
print(info: "Image directory: \(imageDirectory.absoluteURL.path)")
|
print(info: "Image directory: \(imageDirectory.absoluteURL.path)")
|
||||||
print(info: "Model path: \(classifierUrl.path)")
|
print(info: "Model path: \(classifierUrl.path)")
|
||||||
print(info: "Version: \(newVersion)")
|
print(info: "Version: \(newVersion)")
|
||||||
print(info: "Classes: \(classes.count)")
|
print(info: "Classes: \(classes.count)")
|
||||||
print(info: "Iterations: \(configuration.trainingIterations)")
|
print(info: "Iterations: \(configuration.trainingIterations)")
|
||||||
|
|
||||||
|
if !skipTraining {
|
||||||
try await trainAndSaveModel()
|
try await trainAndSaveModel()
|
||||||
|
}
|
||||||
try await uploadModel(version: newVersion)
|
try await uploadModel(version: newVersion)
|
||||||
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
|
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
|
||||||
try await createThumbnails(changed: changedMainImages)
|
try await createThumbnails(changed: changedMainImages)
|
||||||
@ -198,7 +203,7 @@ final class ClassifierCreator {
|
|||||||
let url = imageUrl(base: server.appendingPathComponent("images"), image: image)
|
let url = imageUrl(base: server.appendingPathComponent("images"), image: image)
|
||||||
let tempFile: URL, response: URLResponse
|
let tempFile: URL, response: URLResponse
|
||||||
do {
|
do {
|
||||||
(tempFile, response) = try await URLSession.shared.download(from: url)
|
(tempFile, response) = try await urlSession.download(from: url)
|
||||||
} catch {
|
} catch {
|
||||||
throw TrainingError.failedToLoadImage(image, error)
|
throw TrainingError.failedToLoadImage(image, error)
|
||||||
}
|
}
|
||||||
@ -520,7 +525,7 @@ final class ClassifierCreator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private func perform(_ request: URLRequest) async throws -> Data {
|
private func perform(_ request: URLRequest) async throws -> Data {
|
||||||
let (data, response) = try await URLSession.shared.data(for: request)
|
let (data, response) = try await urlSession.data(for: request)
|
||||||
let code = (response as! HTTPURLResponse).statusCode
|
let code = (response as! HTTPURLResponse).statusCode
|
||||||
guard code == 200 else {
|
guard code == 200 else {
|
||||||
throw TrainingError.invalidResponse(request.url!, code)
|
throw TrainingError.invalidResponse(request.url!, code)
|
||||||
|
Loading…
Reference in New Issue
Block a user