Compare commits

...

2 Commits

Author SHA1 Message Date
Christoph Hagen
93af0cb76d Don't cache web requests 2024-06-12 16:14:05 +02:00
Christoph Hagen
58b1dcb277 Allow skipping of training 2024-06-12 16:13:42 +02:00
2 changed files with 13 additions and 6 deletions

View File

@ -24,6 +24,9 @@ struct CapTrain: AsyncParsableCommand {
@Option(name: .shortAndLong, help: "The path to the folder where the content (images, classifier, thumbnails) is stored") @Option(name: .shortAndLong, help: "The path to the folder where the content (images, classifier, thumbnails) is stored")
var folder: String? var folder: String?
@Flag(name: .long, help: "Only upload the previous classifier")
var skipTraining: Bool = false
func run() async throws { func run() async throws {
let configurationFile = try configurationFile() let configurationFile = try configurationFile()
let iterations = iterations ?? configurationFile?.iterations ?? CapTrain.defaultIterations let iterations = iterations ?? configurationFile?.iterations ?? CapTrain.defaultIterations
@ -36,7 +39,6 @@ struct CapTrain: AsyncParsableCommand {
guard let authentication = authentication ?? configurationFile?.authentication else { guard let authentication = authentication ?? configurationFile?.authentication else {
throw TrainingError.missingArguments("authentication") throw TrainingError.missingArguments("authentication")
} }
let configuration = Configuration( let configuration = Configuration(
contentFolder: contentFolder, contentFolder: contentFolder,
trainingIterations: iterations, trainingIterations: iterations,
@ -44,7 +46,7 @@ struct CapTrain: AsyncParsableCommand {
authenticationToken: authentication) authenticationToken: authentication)
let creator = try ClassifierCreator(configuration: configuration, resume: resume) let creator = try ClassifierCreator(configuration: configuration, resume: resume)
try await creator.run() try await creator.run(skipTraining: skipTraining)
} }
private func configurationFile() throws -> ConfigurationFile? { private func configurationFile() throws -> ConfigurationFile? {

View File

@ -23,6 +23,8 @@ final class ClassifierCreator {
let classifierUrl: URL let classifierUrl: URL
let urlSession = URLSession(configuration: .ephemeral)
let df = DateFormatter() let df = DateFormatter()
private func print(info: String) { private func print(info: String) {
@ -59,7 +61,7 @@ final class ClassifierCreator {
// MARK: Main function // MARK: Main function
func run() async throws { func run(skipTraining: Bool) async throws {
let imagesSnapshotDate = Date() let imagesSnapshotDate = Date()
let (classes, changedImageCount, changedMainImages) = try await loadImages() let (classes, changedImageCount, changedMainImages) = try await loadImages()
@ -78,13 +80,16 @@ final class ClassifierCreator {
let classifierVersion = try await getClassifierVersion() let classifierVersion = try await getClassifierVersion()
let newVersion = classifierVersion + 1 let newVersion = classifierVersion + 1
print(info: "Server: \(server.path)")
print(info: "Image directory: \(imageDirectory.absoluteURL.path)") print(info: "Image directory: \(imageDirectory.absoluteURL.path)")
print(info: "Model path: \(classifierUrl.path)") print(info: "Model path: \(classifierUrl.path)")
print(info: "Version: \(newVersion)") print(info: "Version: \(newVersion)")
print(info: "Classes: \(classes.count)") print(info: "Classes: \(classes.count)")
print(info: "Iterations: \(configuration.trainingIterations)") print(info: "Iterations: \(configuration.trainingIterations)")
try await trainAndSaveModel() if !skipTraining {
try await trainAndSaveModel()
}
try await uploadModel(version: newVersion) try await uploadModel(version: newVersion)
try await upload(classes: classes, lastUpdate: imagesSnapshotDate) try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
try await createThumbnails(changed: changedMainImages) try await createThumbnails(changed: changedMainImages)
@ -198,7 +203,7 @@ final class ClassifierCreator {
let url = imageUrl(base: server.appendingPathComponent("images"), image: image) let url = imageUrl(base: server.appendingPathComponent("images"), image: image)
let tempFile: URL, response: URLResponse let tempFile: URL, response: URLResponse
do { do {
(tempFile, response) = try await URLSession.shared.download(from: url) (tempFile, response) = try await urlSession.download(from: url)
} catch { } catch {
throw TrainingError.failedToLoadImage(image, error) throw TrainingError.failedToLoadImage(image, error)
} }
@ -520,7 +525,7 @@ final class ClassifierCreator {
} }
private func perform(_ request: URLRequest) async throws -> Data { private func perform(_ request: URLRequest) async throws -> Data {
let (data, response) = try await URLSession.shared.data(for: request) let (data, response) = try await urlSession.data(for: request)
let code = (response as! HTTPURLResponse).statusCode let code = (response as! HTTPURLResponse).statusCode
guard code == 200 else { guard code == 200 else {
throw TrainingError.invalidResponse(request.url!, code) throw TrainingError.invalidResponse(request.url!, code)