import Foundation
import Cocoa
import CreateML

/// Minimal blocking HTTP client for the classifier server.
///
/// Every request blocks the calling thread until the response arrives, which
/// keeps this command-line tool strictly sequential.
final class Server {
    /// Base URL that every endpoint path is appended to.
    let server: URL
    /// Value sent in the `key` header of upload requests.
    let authentication: String

    init(server: URL, authentication: String) {
        self.server = server
        self.authentication = authentication
    }

    /// Performs `request` and blocks until it completes.
    ///
    /// - Returns: The response body when the server answers with a 2xx status,
    ///   or `nil` (after logging the reason) on a transport error, a non-HTTP
    ///   response, a non-2xx status, or a missing body.
    private func wait(for request: URLRequest) -> Data? {
        let group = DispatchGroup()
        group.enter()
        var result: Data? = nil
        URLSession.shared.dataTask(with: request) { data, response, error in
            defer { group.leave() }
            // On a transport failure `response` is nil; report the error
            // instead of crashing on a forced cast.
            if let error = error {
                print("[ERROR] Request failed: \(error)")
                return
            }
            guard let http = response as? HTTPURLResponse else {
                print("[ERROR] Invalid response \(String(describing: response))")
                return
            }
            let code = http.statusCode
            // Any 2xx status means success (e.g. 201 Created for an upload).
            guard (200..<300).contains(code) else {
                print("[ERROR] Invalid response \(code)")
                return
            }
            guard let data = data else {
                print("[ERROR] No response data")
                return
            }
            result = data
        }.resume()
        group.wait()
        return result
    }

    /// Sends `body` as an authenticated POST to `url`.
    ///
    /// - Returns: `true` if the server accepted the request.
    private func post(_ body: Data, to url: URL) -> Bool {
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.httpBody = body
        request.addValue(authentication, forHTTPHeaderField: "key")
        return wait(for: request) != nil
    }

    /// Fetches the classifier version currently published at `<server>/version`.
    ///
    /// - Returns: The parsed version, or `nil` (after logging) when the request
    ///   fails or the body is not a UTF-8 integer.
    func getClassifierVersion() -> Int? {
        let classifierVersionUrl = server.appendingPathComponent("version")
        guard let data = wait(for: URLRequest(url: classifierVersionUrl)) else {
            return nil
        }
        guard let string = String(data: data, encoding: .utf8) else {
            print("[ERROR] Invalid classifier version \(data)")
            return nil
        }
        // Tolerate surrounding whitespace (e.g. a trailing newline) in the
        // plain-text body; `Int.init` would otherwise reject it.
        let trimmed = string.trimmingCharacters(in: .whitespacesAndNewlines)
        guard let int = Int(trimmed) else {
            print("[ERROR] Invalid classifier version \(string)")
            return nil
        }
        return int
    }

    /// Uploads serialized model bytes to `<server>/classifier/<version>`.
    ///
    /// - Returns: `true` if the server accepted the upload.
    func upload(classifier: Data, version: Int) -> Bool {
        let classifierUrl = server
            .appendingPathComponent("classifier")
            .appendingPathComponent("\(version)")
        return post(classifier, to: classifierUrl)
    }

    /// Uploads the class labels to `<server>/classes` as newline-separated integers.
    ///
    /// - Returns: `true` if the server accepted the upload.
    func upload(classes: [Int]) -> Bool {
        let classesUrl = server.appendingPathComponent("classes")
        let body = Data(classes.map(String.init).joined(separator: "\n").utf8)
        return post(body, to: classesUrl)
    }
}

// MARK: - Command-line entry point

// Arguments: <server-url> <auth-key> <image-directory> <model-output-path> <iterations>
guard CommandLine.argc == 6 else {
    print("[ERROR] Invalid number of arguments")
    print("Usage: \(CommandLine.arguments[0]) <server-url> <auth-key> <image-directory> <model-output-path> <iterations>")
    exit(1)
}

let serverPath = CommandLine.arguments[1]
let authenticationKey = CommandLine.arguments[2]
let imageDirectory = URL(fileURLWithPath: CommandLine.arguments[3])
let classifierUrl = URL(fileURLWithPath: CommandLine.arguments[4])
let iterationsString = CommandLine.arguments[5]

guard let serverUrl = URL(string: serverPath) else {
    print("[ERROR] Invalid server path argument")
    exit(1)
}

guard let iterations = Int(iterationsString) else {
    print("[ERROR] Invalid iterations argument")
    exit(1)
}

let server = Server(server: serverUrl, authentication: authenticationKey)

// Class labels are the integer-named entries of the image directory.
// `contentsOfDirectory` returns entries in no defined order, so sort them to
// make the uploaded class list deterministic.
let classes: [Int]
do {
    classes = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
        .compactMap(Int.init)
        .sorted()
} catch {
    print("[ERROR] Failed to get model classes: \(error)")
    exit(1)
}

guard let oldVersion = server.getClassifierVersion() else {
    print("[ERROR] Failed to get classifier version")
    exit(1)
}
let newVersion = oldVersion + 1

print("[INFO] Image directory: \(imageDirectory.path)")
print("[INFO] Model path: \(classifierUrl.path)")
print("[INFO] Version: \(newVersion)")
print("[INFO] Classes: \(classes.count)")
print("[INFO] Iterations: \(iterations)")

var params = MLImageClassifier.ModelParameters(augmentation: [])
params.maxIterations = iterations

let model: MLImageClassifier
do {
    model = try MLImageClassifier(
        trainingData: .labeledDirectories(at: imageDirectory),
        parameters: params)
} catch {
    print("[ERROR] Failed to create classifier: \(error)")
    exit(1)
}

print("[INFO] Saving classifier...")
do {
    try model.write(to: classifierUrl)
} catch {
    print("[ERROR] Failed to save model to file: \(error)")
    exit(1)
}

print("[INFO] Uploading classifier...")
// Read back the serialized model; report failures like every other step
// instead of letting an uncaught top-level error abort the process.
let modelData: Data
do {
    modelData = try Data(contentsOf: classifierUrl)
} catch {
    print("[ERROR] Failed to read saved model: \(error)")
    exit(1)
}

guard server.upload(classifier: modelData, version: newVersion) else {
    print("[ERROR] Failed to upload classifier")
    exit(1)
}

print("[INFO] Uploading trained classes...")
guard server.upload(classes: classes) else {
    print("[ERROR] Failed to upload classes")
    exit(1)
}

print("[INFO] Done")
exit(0)