From 9a50328b9349cf603151a78761ab29a40d3ef5a9 Mon Sep 17 00:00:00 2001 From: Christoph Hagen Date: Mon, 23 Oct 2023 12:28:35 +0200 Subject: [PATCH] Add code from training script --- Config/config_example.json | 6 + Package.swift | 4 +- Sources/Cap.swift | 27 ++ Sources/CapTrain.swift | 18 ++ Sources/Cap_Train.swift | 14 - Sources/ClassifierCreator.swift | 463 ++++++++++++++++++++++++++++++++ Sources/Configuration.swift | 40 +++ Sources/TrainingError.swift | 27 ++ 8 files changed, 582 insertions(+), 17 deletions(-) create mode 100644 Config/config_example.json create mode 100644 Sources/Cap.swift create mode 100644 Sources/CapTrain.swift delete mode 100644 Sources/Cap_Train.swift create mode 100644 Sources/ClassifierCreator.swift create mode 100644 Sources/Configuration.swift create mode 100644 Sources/TrainingError.swift diff --git a/Config/config_example.json b/Config/config_example.json new file mode 100644 index 0000000..af6a6fe --- /dev/null +++ b/Config/config_example.json @@ -0,0 +1,6 @@ +{ + "contentFolder": "../Public", + "trainingIterations": 20, + "serverPath": "https://mydomain.com/caps", + "authenticationToken": "mysecretkey" +} \ No newline at end of file diff --git a/Package.swift b/Package.swift index 871d392..ba8c899 100644 --- a/Package.swift +++ b/Package.swift @@ -1,16 +1,14 @@ // swift-tools-version: 5.9 -// The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription let package = Package( name: "Cap-Train", + platforms: [.macOS(.v12)], dependencies: [ .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.0"), ], targets: [ - // Targets are the basic building blocks of a package, defining a module or a test suite. - // Targets can depend on other targets in this package and products from dependencies. 
.executableTarget( name: "Cap-Train", dependencies: [ diff --git a/Sources/Cap.swift b/Sources/Cap.swift new file mode 100644 index 0000000..3a436c3 --- /dev/null +++ b/Sources/Cap.swift @@ -0,0 +1,27 @@ +import Foundation + +struct Cap: Codable { + + let id: Int + + let count: Int + + enum CodingKeys: String, CodingKey { + case id = "i" + case count = "c" + } +} + +struct CapImage: Equatable { + + let cap: Int + + let image: Int +} + +extension CapImage: CustomStringConvertible { + + var description: String { + "image \(image) of cap \(cap)" + } +} diff --git a/Sources/CapTrain.swift b/Sources/CapTrain.swift new file mode 100644 index 0000000..a6bba5d --- /dev/null +++ b/Sources/CapTrain.swift @@ -0,0 +1,18 @@ +import ArgumentParser +import Foundation + +@main +struct CapTrain: AsyncParsableCommand { + + @Argument(help: "The path to the configuration file") + var configPath: String + + func run() async throws { + let configurationFileUrl = URL(fileURLWithPath: configPath) + + let configuration = try Configuration(at: configurationFileUrl) + let creator = try ClassifierCreator(configuration: configuration) + + try await creator.run() + } +} diff --git a/Sources/Cap_Train.swift b/Sources/Cap_Train.swift deleted file mode 100644 index c922dba..0000000 --- a/Sources/Cap_Train.swift +++ /dev/null @@ -1,14 +0,0 @@ -// The Swift Programming Language -// https://docs.swift.org/swift-book -// -// Swift Argument Parser -// https://swiftpackageindex.com/apple/swift-argument-parser/documentation - -import ArgumentParser - -@main -struct Cap_Train: ParsableCommand { - mutating func run() throws { - print("Hello, world!") - } -} \ No newline at end of file diff --git a/Sources/ClassifierCreator.swift b/Sources/ClassifierCreator.swift new file mode 100644 index 0000000..5ee72a4 --- /dev/null +++ b/Sources/ClassifierCreator.swift @@ -0,0 +1,463 @@ +import Foundation +import CreateML + +final class ClassifierCreator { + + let server: URL + + let configuration: Configuration + + 
let imageDirectory: URL + + let thumbnailDirectory: URL + + let classifierUrl: URL + + let df = DateFormatter() + + + // MARK: Step 1: Load configuration + + init(configuration: Configuration) throws { + self.configuration = configuration + self.server = try configuration.serverUrl() + let contentDirectory = URL(fileURLWithPath: configuration.contentFolder) + self.imageDirectory = contentDirectory.appendingPathComponent("images") + self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel") + self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails") + df.dateFormat = "yy-MM-dd-HH-mm-ss" + } + + // MARK: Main function + + func run() async throws { + let imagesSnapshotDate = Date() + let (classes, changedImageCount, changedMainImages) = try await loadImages() + + guard !classes.isEmpty else { + print("[INFO] No image classes found, exiting...") + return + } + + guard changedImageCount > 0 else { + print("[INFO] No changed images, so no new classifier trained") + await createThumbnails(changed: changedMainImages) + print("[INFO] Done") + return + } + + let classifierVersion = try await getClassifierVersion() + let newVersion = classifierVersion + 1 + + print("[INFO] Image directory: \(imageDirectory.absoluteURL.path)") + print("[INFO] Model path: \(classifierUrl.path)") + print("[INFO] Version: \(newVersion)") + print("[INFO] Classes: \(classes.count)") + print("[INFO] Iterations: \(configuration.trainingIterations)") + + try trainModel() + try await upload(version: newVersion) + try await upload(classes: classes, lastUpdate: imagesSnapshotDate) + await createThumbnails(changed: changedMainImages) + print("[INFO] Done") + } + + // MARK: Step 2: Load changed images + + func loadImages() async throws -> (classes: [Int], changedImageCount: Int, changedMainImages: [Int]) { + guard createFolderIfMissing(imageDirectory) else { + throw TrainingError.mainImageFolderNotCreated + } + let imageCounts = await getImageCounts() + let 
missingImageList: [CapImage] = imageCounts + .sorted { $0.key < $1.key } + .reduce(into: []) { list, pair in + let missingImagesForCap: [CapImage] = (0.. 0 ? " (\(imagesAlreadyLoad) already loaded)" : "" + if filteredChangeList.isEmpty { + print("[INFO] No changed images to load" + suffix) + } else { + print("[INFO] Loading \(filteredChangeList.count) changed images" + suffix) + try await loadImages(filteredChangeList) + } + + let changedMainImages = changedImageList.filter { $0.image == 0 }.map { $0.cap } + let classes = imageCounts.keys.sorted() + + // Delete any image folders not present as caps + try deleteUnnecessaryImageFolders(caps: classes) + + return (classes, missingImageList.count + changedImageList.count, changedMainImages) + } + + private func getImageCounts() async -> [Int : Int] { + guard let data: Data = await get(server.appendingPathComponent("caps.json")) else { + return [:] + } + do { + return try JSONDecoder().decode([Cap].self, from: data) + .reduce(into: [:]) { $0[$1.id] = $1.count } + } catch { + print("[ERROR] Failed to decode cap database: \(error)") + return [:] + } + } + + private func deleteUnnecessaryImageFolders(caps: [Int]) throws { + let validNames = caps.map { String(format: "%04d", $0) } + let folders: [String] + do { + folders = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path) + } catch { + print("[ERROR] Failed to get list of image folders: \(error)") + throw TrainingError.failedToGetListOfImageFolders + } + for folder in folders { + if validNames.contains(folder) { + continue + } + + // Not a valid cap folder + let url = imageDirectory.appendingPathComponent(folder) + do { + try FileManager.default.removeItem(at: url) + print("[INFO] Removed unused image folder '\(folder)'") + } catch { + print("[ERROR] Failed to delete unused image folder \(folder): \(error)") + throw TrainingError.failedToRemoveImageFolder + } + } + } + + private func imageUrl(base: URL, image: CapImage) -> URL { + 
base.appendingPathComponent(String(format: "%04d/%04d-%02d.jpg", image.cap, image.cap, image.image)) + } + + private func load(image: CapImage) async -> Bool { + guard createFolderIfMissing(imageDirectory.appendingPathComponent(String(format: "%04d", image.cap))) else { + return false + } + let url = imageUrl(base: server.appendingPathComponent("images"), image: image) + let tempFile: URL, response: URLResponse + do { + (tempFile, response) = try await URLSession.shared.download(from: url) + } catch { + print("[ERROR] Failed to load \(image): \(error)") + return false + } + let responseCode = (response as! HTTPURLResponse).statusCode + guard responseCode == 200 else { + print("[ERROR] Failed to load \(image): Response \(responseCode)") + return false + } + do { + let localUrl = imageUrl(base: imageDirectory, image: image) + if FileManager.default.fileExists(atPath: localUrl.path) { + try FileManager.default.removeItem(at: localUrl) + } + try FileManager.default.moveItem(at: tempFile, to: localUrl) + return true + } catch { + print("[ERROR] Failed to save \(image): \(error)") + return false + } + } + + private func loadImages(_ list: [CapImage]) async throws { + var loadedImages = 0 + await withTaskGroup(of: Bool.self) { group in + for image in list { + group.addTask { + await self.load(image: image) + } + } + for await loaded in group { + if loaded { + loadedImages += 1 + } + } + } + if loadedImages != list.count { + print("[ERROR] Only \(loadedImages) of \(list.count) images loaded") + throw TrainingError.failedToLoadImages + } + } + + func getChangedImageList() async throws -> [CapImage] { + guard let string: String = await get(server.appendingPathComponent("changes.txt")) else { + print("[ERROR] Failed to get list of changed images") + throw TrainingError.failedToGetChangedImagesList + } + + return string + .components(separatedBy: "\n") + .filter { $0 != "" } + .compactMap { + let parts = $0.components(separatedBy: ":") + guard parts.count == 3 else { + return 
nil + } + /* + guard let date = df.date(from: parts[0]) else { + print("[WARN] Invalid date \(parts[0]) in change list") + return nil + } + */ + guard let cap = Int(parts[1]) else { + print("[WARN] Invalid cap id \(parts[1]) in change list") + return nil + } + guard let image = Int(parts[2]) else { + print("[WARN] Invalid image id \(parts[2]) in change list") + return nil + } + return CapImage(cap: cap, image: image) + } + } + + // MARK: Step 3: Compute version + + func getClassifierVersion() async throws -> Int { + guard let string: String = await get(server.appendingPathComponent("version")) else { + print("[ERROR] Failed to get classifier version") + throw TrainingError.failedToGetClassifierVersion + } + guard let version = Int(string) else { + print("[ERROR] Invalid classifier version \(string)") + throw TrainingError.invalidClassifierVersion(string) + } + return version + } + + // MARK: Step 4: Train classifier + + func trainModel() throws { + var params = MLImageClassifier.ModelParameters(augmentation: []) + params.maxIterations = configuration.trainingIterations + let model: MLImageClassifier + do { + model = try MLImageClassifier( + trainingData: .labeledDirectories(at: imageDirectory), + parameters: params) + } catch { + print("[ERROR] Failed to create classifier: \(error)") + throw TrainingError.failedToCreateClassifier(error) + } + + print("[INFO] Saving classifier...") + do { + try model.write(to: classifierUrl) + } catch { + print("[ERROR] Failed to save model to file: \(error)") + throw TrainingError.failedToWriteClassifier(error) + } + } + + // MARK: Step 5: Upload classifier + + func upload(version: Int) async throws { + print("[INFO] Uploading classifier...") + let modelData: Data + do { + modelData = try Data(contentsOf: classifierUrl) + } catch { + print("[ERROR] Failed to read classifier data: \(error)") + throw TrainingError.failedToReadClassifierData(error) + } + + let success = await post( + url: 
server.appendingPathComponent("classifier/\(version)"), + body: modelData) + guard success else { + throw TrainingError.failedToUploadClassifier + } + } + + // MARK: Step 6: Update classes + + func upload(classes: [Int], lastUpdate: Date) async throws { + print("[INFO] Uploading trained classes...") + let dateString = df.string(from: lastUpdate) + let url = server.appendingPathComponent("classes/\(dateString)") + let body = classes.map(String.init).joined(separator: ",").data(using: .utf8)! + guard await post(url: url, body: body) else { + throw TrainingError.failedToUploadClassifierClasses + } + } + + // MARK: Step 7: Create thumbnails + + func createThumbnails(changed: [Int]) async { + guard checkMagickAvailability() else { + return + } + guard createFolderIfMissing(thumbnailDirectory) else { + print("[ERROR] Failed to create folder for thumbnails") + return + } + let capIdsOfMissingThumbnails = await getMissingThumbnailIds() + let all = Set(capIdsOfMissingThumbnails).union(changed) + print("[INFO] Creating \(all.count) thumbnails...") + for cap in all { + await createThumbnail(for: cap) + } + } + + func checkMagickAvailability() -> Bool { + do { + let (code, output) = try safeShell("magick --version") + guard code == 0, let version = output.components(separatedBy: "ImageMagick ").dropFirst().first? 
+ .components(separatedBy: " ").first else { + print("[ERROR] Magick not found, install using 'brew install imagemagick'") + return false + } + print("[INFO] Using magick \(version)") + } catch { + print("[ERROR] Failed to get version of magick: (\(error))") + return false + } + return true + } + + private func getMissingThumbnailIds() async -> [Int] { + guard let string: String = await get(server.appendingPathComponent("thumbnails/missing")) else { + print("[ERROR] Failed to get missing thumbnails") + return [] + } + return string.components(separatedBy: ",").compactMap(Int.init) + } + + private func createThumbnail(for cap: Int) async { + let mainImage = CapImage(cap: cap, image: 0) + let inputUrl = imageUrl(base: imageDirectory, image: mainImage) + guard FileManager.default.fileExists(atPath: inputUrl.path) else { + print("[ERROR] Local main image not found for cap \(cap): \(inputUrl.path)") + return + } + + let output = thumbnailDirectory.appendingPathComponent(String(format: "%04d.jpg", cap)) + do { + let command = "magick convert \(inputUrl.path) -quality 70% -resize 100x100 \(output.path)" + let (code, output) = try safeShell(command) + if code != 0 { + print("Failed to create thumbnail for cap \(cap): \(output)") + return + } + } catch { + print("Failed to read created thumbnail for cap \(cap): \(error)") + return + } + + let data: Data + do { + data = try Data(contentsOf: output) + } catch { + print("Failed to read created thumbnail for cap \(cap): \(error)") + return + } + guard await post(url: server.appendingPathComponent("thumbnails/\(cap)"), body: data) else { + print("Failed to upload thumbnail for cap \(cap)") + return + } + } + + // MARK: Helper + + @discardableResult + private func safeShell(_ command: String) throws -> (code: Int32, output: String) { + let task = Process() + let pipe = Pipe() + + task.standardOutput = pipe + task.standardError = pipe + task.arguments = ["-cl", command] + task.executableURL = URL(fileURLWithPath: "/bin/zsh") + 
task.standardInput = nil + + try task.run() + task.waitUntilExit() + + let data = pipe.fileHandleForReading.readDataToEndOfFile() + let output = String(data: data, encoding: .utf8)! + return (task.terminationStatus, output) + } + + + private func createFolderIfMissing(_ folder: URL) -> Bool { + guard !FileManager.default.fileExists(atPath: folder.path) else { + return true + } + do { + try FileManager.default.createDirectory(at: folder, withIntermediateDirectories: true) + return true + } catch { + print("[ERROR] Failed to create directory \(folder.path): \(error)") + return false + } + } + + // MARK: Requests + + private func post(url: URL, body: Data) async -> Bool { + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.httpBody = body + request.addValue(configuration.authenticationToken, forHTTPHeaderField: "key") + return await perform(request) != nil + } + + private func perform(_ request: URLRequest) async -> Data? { + let data: Data + let response: URLResponse + do { + (data, response) = try await URLSession.shared.data(for: request) + } catch { + print("[ERROR] Request to \(request.url!.absoluteString) failed: \(error)") + return nil + } + let code = (response as! HTTPURLResponse).statusCode + guard code == 200 else { + print("[ERROR] Request to \(request.url!.absoluteString): Invalid response \(code)") + return nil + } + return data + } + + private func get(_ url: URL) async -> Data? { + await perform(URLRequest(url: url)) + } + + private func get(_ url: URL) async -> String? 
{ + guard let data: Data = await get(url) else { + return nil + } + guard let string = String(data: data, encoding: .utf8) else { + print("[ERROR] Invalid string response \(data)") + return nil + } + return string + } +} diff --git a/Sources/Configuration.swift b/Sources/Configuration.swift new file mode 100644 index 0000000..656221d --- /dev/null +++ b/Sources/Configuration.swift @@ -0,0 +1,40 @@ +import Foundation + +struct Configuration: Codable { + + let contentFolder: String + + let trainingIterations: Int + + let serverPath: String + + let authenticationToken: String + + init(at url: URL) throws { + guard FileManager.default.fileExists(atPath: url.path) else { + print("[ERROR] No configuration at \(url.absoluteURL.path)") + throw TrainingError.configurationFileMissing + } + let data: Data + do { + data = try Data(contentsOf: url) + } catch { + print("[ERROR] Failed to load configuration data at \(url.absoluteURL.path): \(error)") + throw TrainingError.configurationFileUnreadable + } + do { + self = try JSONDecoder().decode(Configuration.self, from: data) + } catch { + print("[ERROR] Failed to decode configuration at \(url.absoluteURL.path): \(error)") + throw TrainingError.configurationFileDecodingFailed + } + } + + func serverUrl() throws -> URL { + guard let serverUrl = URL(string: serverPath) else { + print("[ERROR] Configuration: Invalid server path \(serverPath)") + throw TrainingError.invalidServerPath + } + return serverUrl + } +} diff --git a/Sources/TrainingError.swift b/Sources/TrainingError.swift new file mode 100644 index 0000000..b0648a4 --- /dev/null +++ b/Sources/TrainingError.swift @@ -0,0 +1,27 @@ +import Foundation + +enum TrainingError: Error { + + case configurationFileMissing + case configurationFileUnreadable + case configurationFileDecodingFailed + case invalidServerPath + case mainImageFolderNotCreated + case failedToLoadImages + + case failedToGetListOfImageFolders + case failedToRemoveImageFolder + + case failedToGetClassifierVersion 
+ case invalidClassifierVersion(String) + + case failedToCreateClassifier(Error) + case failedToWriteClassifier(Error) + + case failedToGetChangedImagesList + + case failedToReadClassifierData(Error) + case failedToUploadClassifier + + case failedToUploadClassifierClasses +}