From d516c1acd6ce07baa0359ae20331d6594e88d3a3 Mon Sep 17 00:00:00 2001 From: Christoph Hagen Date: Sun, 15 Jan 2023 11:21:47 +0100 Subject: [PATCH] Switch to pure Swift training script --- .gitignore | 2 +- Sources/App/CapServer.swift | 28 +- Sources/App/routes.swift | 20 +- Training/config_example.json | 7 + Training/config_example.sh | 10 - Training/train.sh | 48 ---- Training/train.swift | 488 +++++++++++++++++++++++++---------- 7 files changed, 403 insertions(+), 200 deletions(-) create mode 100644 Training/config_example.json delete mode 100644 Training/config_example.sh delete mode 100755 Training/train.sh diff --git a/.gitignore b/.gitignore index 8ba8496..e62264a 100644 --- a/.gitignore +++ b/.gitignore @@ -13,5 +13,5 @@ Public/classifier.version Public/classifier.mlmodel Public/caps.json Training/backup/ -Training/config.sh +Training/config.json Public/thumbnails diff --git a/Sources/App/CapServer.swift b/Sources/App/CapServer.swift index be8c37f..eb66514 100644 --- a/Sources/App/CapServer.swift +++ b/Sources/App/CapServer.swift @@ -22,6 +22,8 @@ final class CapServer: ServerOwner { private let fm = FileManager.default + private let changedImageEntryDateFormatter: DateFormatter + // MARK: Caps private var writers: Set @@ -77,6 +79,8 @@ final class CapServer: ServerOwner { self.classifierFile = folder.appendingPathComponent("classifier.mlmodel") self.changedImagesFile = folder.appendingPathComponent("changes.txt") self.writers = Set(writers) + self.changedImageEntryDateFormatter = DateFormatter() + changedImageEntryDateFormatter.dateFormat = "yy-MM-dd-HH-mm-ss" } func loadData() throws { @@ -303,9 +307,7 @@ final class CapServer: ServerOwner { unwrittenImageChanges = entries try? 
handle.close() } - let df = DateFormatter() - df.dateFormat = "yy-MM-dd-HH-mm-ss" - let dateString = df.string(from: Date()) + let dateString = changedImageEntryDateFormatter.string(from: Date()) while let entry = entries.popLast() { let content = "\(dateString):\(entry.cap):\(entry.image)\n".data(using: .utf8)! try handle.write(contentsOf: content) @@ -333,11 +335,25 @@ final class CapServer: ServerOwner { } } - func emptyChangedImageListFile() { + func removeAllEntriesInImageChangeList(before date: Date) { do { - try Data().write(to: changedImagesFile) + try String(contentsOf: changedImagesFile) + .components(separatedBy: "\n") + .filter { $0 != "" } + .compactMap { line -> String? in + guard let entryDate = changedImageEntryDateFormatter.date(from: line.components(separatedBy: ":").first!) else { + return nil + } + guard entryDate >= date else { // "before" is strict: keep entries stamped in the snapshot second + return nil + } + return "\(line)\n" // re-terminate each record so later appends start on a new line + } + .joined() + .data(using: .utf8)! + .write(to: changedImagesFile) } catch { - log("Failed to empty changed images file: \(error)") + log("Failed to update changed images file: \(error)") } } diff --git a/Sources/App/routes.swift b/Sources/App/routes.swift index 3010ace..d1b6950 100755 --- a/Sources/App/routes.swift +++ b/Sources/App/routes.swift @@ -4,6 +4,13 @@ import Foundation /// The decoder to extract caps from JSON payloads given to the `cap` route.
private let decoder = JSONDecoder() +/// The date formatter to decode dates in requests +private let dateFormatter: DateFormatter = { + let df = DateFormatter() + df.dateFormat = "yy-MM-dd-HH-mm-ss" + return df +}() + private func authorize(_ request: Request) throws { guard let key = request.headers.first(name: "key") else { throw Abort(.badRequest) // 400 @@ -65,7 +72,16 @@ func routes(_ app: Application) { } // Update the trained classes - app.postCatching("classes") { request in + app.postCatching("classes", ":date") { request in + guard let dateString = request.parameters.get("date") else { + log("Invalid parameter for date") + throw Abort(.badRequest) + } + guard let date = dateFormatter.date(from: dateString) else { + log("Invalid date specification") + throw Abort(.badRequest) + } + try authorize(request) guard let buffer = request.body.data else { log("Missing body data: \(request.body.description)") @@ -77,6 +93,6 @@ func routes(_ app: Application) { throw CapError.invalidBody } server.updateTrainedClasses(content: content) - server.emptyChangedImageListFile() + server.removeAllEntriesInImageChangeList(before: date) } } diff --git a/Training/config_example.json b/Training/config_example.json new file mode 100644 index 0000000..3a401ec --- /dev/null +++ b/Training/config_example.json @@ -0,0 +1,7 @@ +{ + "imageDirectory": "../Public/images", + "classifierModelPath": "../Public/classifier.mlmodel", + "trainingIterations": 20, + "serverPath": "https://mydomain.com/caps", + "authenticationToken": "mysecretkey" +} \ No newline at end of file diff --git a/Training/config_example.sh b/Training/config_example.sh deleted file mode 100644 index 56f0534..0000000 --- a/Training/config_example.sh +++ /dev/null @@ -1,10 +0,0 @@ -IMAGE_DIR="../Public/images" -MODEL_FILE="../Public/classifier.mlmodel" - -TRAINING_ITERATIONS="20" -SSH_PORT="22" -SERVER="pi@mydomain.com" -SERVER_ROOT_PATH="/caps/Public" -SERVER_PATH="https://mydomain.com/caps" -SERVER_AUTH="mysecretkey" -
diff --git a/Training/train.sh b/Training/train.sh deleted file mode 100755 index b2f907e..0000000 --- a/Training/train.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash - -########### SUDO COMMANDS WITHOUT PASSWORD ############################################# -# -# Specific commands must be executable without sudo -# -# Add to sudoers file (sudo visudo): -# -# Disable password for specific commands -# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/chmod -R 755 /data/public/capserver/images -# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/mv /home/pi/classifier.mlmodel /data/public/capserver/ -# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/mv /home/pi/classifier.version /data/public/capserver/ -# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/chown -R www-data\:www-data /data/public/capserver/ -# -######################################################################################## - -########### EXPLANATION OF RSYNC FLAGS ################################################# -# -# -h human readable output -# -v verbose output -# -r recursive -# -P print information about long-running transfers, keep partial files (—-progress --partial) -# -t preserves modification times -# -u only update if newer -# -######################################################################################## - -source config.sh - -echo "[INFO] Ensuring permissions for images on server..." -ssh -p $SSH_PORT ${SERVER} "sudo chmod -R 755 ${SERVER_ROOT_PATH}/images" - -retVal=$? -if [ $retVal -ne 0 ]; then - echo '[ERROR] Failed to change image permissions' - return $retVal -fi - -echo "[INFO] Transferring images from server..." -rsync -hrut --info=progress2 -e "ssh -p ${SSH_PORT}" ${SERVER}:/${SERVER_ROOT_PATH}/images/ "${IMAGE_DIR}" - -retVal=$? 
-if [ $retVal -ne 0 ]; then - echo '[ERROR] Failed to transfer images from server' - return $retVal -fi - -swift train.swift $SERVER_PATH $SERVER_AUTH $IMAGE_DIR $MODEL_FILE $TRAINING_ITERATIONS diff --git a/Training/train.swift b/Training/train.swift index 331410f..f8c966a 100644 --- a/Training/train.swift +++ b/Training/train.swift @@ -2,156 +2,378 @@ import Foundation import Cocoa import CreateML -final class Server { +struct Configuration: Codable { + + let imageDirectory: String + + let classifierModelPath: String + + let trainingIterations: Int + + let serverPath: String + + let authenticationToken: String + + init?(at url: URL) { + do { + let configData = try Data(contentsOf: url) + self = try JSONDecoder().decode(Configuration.self, from: configData) + } catch { + print("[ERROR] Failed to load configuration at \(url.absoluteURL.path): \(error)") + return nil + } + } +} + +struct Cap: Codable { + + let id: Int + + let count: Int + + enum CodingKeys: String, CodingKey { + case id = "i" + case count = "c" + } +} + +final class ClassifierCreator { + + static let configurationFileUrl = URL(fileURLWithPath: "config.json") let server: URL - let authentication: String + let configuration: Configuration - init(server: URL, authentication: String) { - self.server = server - self.authentication = authentication - } + let imageDirectory: URL - private func wait(for request: URLRequest) -> Data? { - let group = DispatchGroup() - group.enter() - var result: Data? = nil - URLSession.shared.dataTask(with: request) { data, response, _ in - defer { group.leave() } - let code = (response as! HTTPURLResponse).statusCode - guard code == 200 else { - print("[ERROR] Invalid response \(code)") - return - } - guard let data = data else { - print("[ERROR] No response data") - return - } - result = data - }.resume() - group.wait() - return result - } + let classifierUrl: URL - func getClassifierVersion() -> Int? 
{ - let group = DispatchGroup() - group.enter() - let classifierVersionUrl = server.appendingPathComponent("version") - guard let data = wait(for: URLRequest(url: classifierVersionUrl)) else { + let df = DateFormatter() + + + // MARK: Step 1: Load configuration + + init?() { + guard let configuration = Configuration(at: ClassifierCreator.configurationFileUrl) else { return nil } - guard let string = String(data: data, encoding: .utf8) else { - print("[ERROR] Invalid classifier version \(data)") + self.configuration = configuration + guard let serverUrl = URL(string: configuration.serverPath) else { + print("[ERROR] Configuration: Invalid server path \(configuration.serverPath)") return nil } - guard let int = Int(string) else { + self.server = serverUrl + self.imageDirectory = URL(fileURLWithPath: configuration.imageDirectory) + self.classifierUrl = URL(fileURLWithPath: configuration.classifierModelPath) + df.dateFormat = "yy-MM-dd-HH-mm-ss" + } + + // MARK: Main function + + func run() async { + let imagesSnapshotDate = Date() + guard let (classes, changedImageCount) = await loadImages() else { + return + } + + guard !classes.isEmpty else { + return + } + + guard changedImageCount > 0 else { + print("[INFO] No changed images, so no new classifier trained") + return + } + + guard let classifierVersion = await getClassifierVersion() else { + return + } + let newVersion = classifierVersion + 1 + + print("[INFO] Image directory: \(imageDirectory.absoluteURL.path)") + print("[INFO] Model path: \(configuration.classifierModelPath)") + print("[INFO] Version: \(newVersion)") + print("[INFO] Classes: \(classes.count)") + print("[INFO] Iterations: \(configuration.trainingIterations)") + + guard trainModel() else { + return + } + + guard await upload(version: newVersion) else { + return + } + + guard await upload(classes: classes, lastUpdate: imagesSnapshotDate) else { + return + } + + print("[INFO] Done") + } + + // MARK: Step 2: Load changed images + + func loadImages() 
async -> (classes: [Int], changedImages: Int)? { + guard createImageFolderIfMissing() else { + return nil + } + let imageCounts = await getImageCounts() + let missingImageList: [(cap: Int, image: Int)] = imageCounts + .sorted { $0.key < $1.key } + .reduce(into: []) { list, pair in + let missingImagesForCap: [(cap: Int, image: Int)] = (0.. Bool { + guard !FileManager.default.fileExists(atPath: imageDirectory.path) else { + return true + } + do { + try FileManager.default.createDirectory(at: imageDirectory, withIntermediateDirectories: true) + return true + } catch { + print("[ERROR] Failed to create image directory: \(error)") + return false + } + } + + private func getImageCounts() async -> [Int : Int] { + guard let data: Data = await get(server.appendingPathComponent("caps.json")) else { + return [:] + } + do { + return try JSONDecoder().decode([Cap].self, from: data) + .reduce(into: [:]) { $0[$1.id] = $1.count } + } catch { + print("[ERROR] Failed to decode cap database: \(error)") + return [:] + } + } + + private func imageUrl(base: URL, cap: Int, image: Int) -> URL { + base.appendingPathComponent(String(format: "%04d/%04d-%02d.jpg", cap, cap, image)) + } + + private func loadImage(cap: Int, image: Int) async -> Bool { + let url = imageUrl(base: server.appendingPathComponent("images"), cap: cap, image: image) + let tempFile: URL, response: URLResponse + do { + (tempFile, response) = try await URLSession.shared.download(from: url) + } catch { + print("[ERROR] Failed to load image \(image) of cap \(cap): \(error)") + return false + } + let responseCode = (response as! 
HTTPURLResponse).statusCode + guard responseCode == 200 else { + print("[ERROR] Failed to load image \(image) of cap \(cap): Response \(responseCode)") + return false + } + do { + let localUrl = imageUrl(base: imageDirectory, cap: cap, image: image) + if FileManager.default.fileExists(atPath: localUrl.path) { + try FileManager.default.removeItem(at: localUrl) + } + try FileManager.default.moveItem(at: tempFile, to: localUrl) + return true + } catch { + print("[ERROR] Failed to save image \(image) of cap \(cap): \(error)") + return false + } + } + + private func loadImages(_ list: [(cap: Int, image: Int)]) async -> Bool { + guard !list.isEmpty else { + return true + } + var loadedImages = 0 + await withTaskGroup(of: Bool.self) { group in + for (cap, image) in list { + group.addTask { + await self.loadImage(cap: cap, image: image) + } + } + for await loaded in group { + if loaded { + loadedImages += 1 + } + } + } + if loadedImages != list.count { + print("[ERROR] Only \(loadedImages) of \(list.count) images loaded") + return false + } + return true + } + + func getChangedImageList() async -> [(cap: Int, image: Int)] { + guard let string: String = await get(server.appendingPathComponent("changes.txt")) else { + print("[ERROR] Failed to get list of changed images") + return [] + } + + return string + .components(separatedBy: "\n") + .filter { $0 != "" } + .compactMap { + let parts = $0.components(separatedBy: ":") + guard parts.count == 3 else { + return nil + } + /* + guard let date = df.date(from: parts[0]) else { + print("[WARN] Invalid date \(parts[0]) in change list") + return nil + } + */ + guard let cap = Int(parts[1]) else { + print("[WARN] Invalid cap id \(parts[1]) in change list") + return nil + } + guard let image = Int(parts[2]) else { + print("[WARN] Invalid image id \(parts[2]) in change list") + return nil + } + return (cap, image) + } + } + + // MARK: Step 3: Compute version + + func getClassifierVersion() async -> Int? 
{ + guard let string: String = await get(server.appendingPathComponent("version")) else { + print("[ERROR] Failed to get classifier version") + return nil + } + guard let version = Int(string) else { print("[ERROR] Invalid classifier version \(string)") return nil } - return int + return version } + + // MARK: Step 4: Train classifier + + func trainModel() -> Bool { + var params = MLImageClassifier.ModelParameters(augmentation: []) + params.maxIterations = configuration.trainingIterations + let model: MLImageClassifier + do { + model = try MLImageClassifier( + trainingData: .labeledDirectories(at: imageDirectory), + parameters: params) + } catch { + print("[ERROR] Failed to create classifier: \(error)") + return false + } - func upload(classifier: Data, version: Int) -> Bool { - let classifierUrl = server - .appendingPathComponent("classifier") - .appendingPathComponent("\(version)") - var request = URLRequest(url: classifierUrl) + print("[INFO] Saving classifier...") + do { + try model.write(to: classifierUrl) + return true + } catch { + print("[ERROR] Failed to save model to file: \(error)") + return false + } + } + + // MARK: Step 5: Upload classifier + + func upload(version: Int) async -> Bool { + print("[INFO] Uploading classifier...") + let modelData: Data + do { + modelData = try Data(contentsOf: classifierUrl) + } catch { + print("[ERROR] Failed to read classifier data: \(error)") + return false + } + + return await post( + url: server.appendingPathComponent("classifier/\(version)"), + body: modelData) + } + + // MARK: Step 6: Update classes + + func upload(classes: [Int], lastUpdate: Date) async -> Bool { + print("[INFO] Uploading trained classes...") + let dateString = df.string(from: lastUpdate) + // The snapshot date is carried in the URL path; post(url:body:) takes no date argument. + return await post( + url: server.appendingPathComponent("classes/\(dateString)"), + body: classes.map(String.init).joined(separator: "\n").data(using: .utf8)!) + } + + // MARK: Requests + + private func post(url: URL, body: Data) async ->
Bool { + var request = URLRequest(url: url) request.httpMethod = "POST" - request.httpBody = classifier - request.addValue(authentication, forHTTPHeaderField: "key") - return wait(for: request) != nil + request.httpBody = body + request.addValue(configuration.authenticationToken, forHTTPHeaderField: "key") + return await perform(request) != nil } - - func upload(classes: [Int]) -> Bool { - let classifierUrl = server - .appendingPathComponent("classes") - var request = URLRequest(url: classifierUrl) - request.httpMethod = "POST" - request.httpBody = classes.map(String.init).joined(separator: "\n").data(using: .utf8)! - request.addValue(authentication, forHTTPHeaderField: "key") - return wait(for: request) != nil + + private func perform(_ request: URLRequest) async -> Data? { + let data: Data + let response: URLResponse + do { + (data, response) = try await URLSession.shared.data(for: request) + } catch { + print("[ERROR] Request failed: \(error)") + return nil + } + let code = (response as! HTTPURLResponse).statusCode + guard code == 200 else { + print("[ERROR] Invalid response \(code)") + return nil + } + return data + } + + private func get(_ url: URL) async -> Data? { + await perform(URLRequest(url: url)) + } + + private func get(_ url: URL) async -> String? 
{ + guard let data: Data = await get(url) else { + return nil + } + guard let string = String(data: data, encoding: .utf8) else { + print("[ERROR] Invalid string response \(data)") + return nil + } + return string } } - -let count = CommandLine.argc -guard count == 6 else { - print("[ERROR] Invalid number of arguments") - exit(1) -} -let serverPath = CommandLine.arguments[1] -let authenticationKey = CommandLine.arguments[2] -let imageDirectory = URL(fileURLWithPath: CommandLine.arguments[3]) -let classifierUrl = URL(fileURLWithPath: CommandLine.arguments[4]) -let iterationsString = CommandLine.arguments[5] - -guard let serverUrl = URL(string: serverPath) else { - print("[ERROR] Invalid server path argument") - exit(1) -} -guard let iterations = Int(iterationsString) else { - print("[ERROR] Invalid iterations argument") - exit(1) -} - -let server = Server(server: serverUrl, authentication: authenticationKey) - -let classes: [Int] - -do { - classes = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path) - .compactMap(Int.init) -} catch { - print("[ERROR] Failed to get model classes: \(error)") - exit(1) -} - -guard let oldVersion = server.getClassifierVersion() else { - print("[ERROR] Failed to get classifier version") - exit(1) -} -let newVersion = oldVersion + 1 - -print("[INFO] Image directory: \(imageDirectory.path)") -print("[INFO] Model path: \(classifierUrl.path)") -print("[INFO] Version: \(newVersion)") -print("[INFO] Classes: \(classes.count)") -print("[INFO] Iterations: \(iterations)") - -var params = MLImageClassifier.ModelParameters(augmentation: []) -params.maxIterations = iterations - -let model: MLImageClassifier -do { - model = try MLImageClassifier( - trainingData: .labeledDirectories(at: imageDirectory), - parameters: params) -} catch { - print("[ERROR] Failed to create classifier: \(error)") - exit(1) -} - -print("[INFO] Saving classifier...") -do { - try model.write(to: classifierUrl) -} catch { - print("[ERROR] Failed to save 
model to file: \(error)") - exit(1) -} - -print("[INFO] Uploading classifier...") -let modelData = try Data(contentsOf: classifierUrl) -guard server.upload(classifier: modelData, version: newVersion) else { - print("[ERROR] Failed to upload classifier") - exit(1) -} - -print("[INFO] Uploading trained classes...") -guard server.upload(classes: classes) else { - print("[ERROR] Failed to upload classes") - exit(1) -} - -print("[INFO] Done") -exit(0) +await ClassifierCreator()?.run()