158 lines
4.6 KiB
Swift
158 lines
4.6 KiB
Swift
import Foundation
|
|
import Cocoa
|
|
import CreateML
|
|
|
|
/// Minimal synchronous client for the classifier server.
///
/// Every request blocks the calling thread until `URLSession` delivers the
/// completion handler, keeping this command-line tool strictly sequential.
final class Server {

    /// Base URL that endpoint paths (`version`, `classifier/<n>`, `classes`) are appended to.
    let server: URL

    /// Value sent in the `key` header of upload (POST) requests.
    let authentication: String

    init(server: URL, authentication: String) {
        self.server = server
        self.authentication = authentication
    }

    /// Decodes a classifier version from the body of the `version` endpoint.
    ///
    /// - Parameter data: Raw response body.
    /// - Returns: The integer version, or `nil` (after printing an `[ERROR]` line)
    ///   when the body is not UTF-8 or not an integer. Leading/trailing
    ///   whitespace and newlines are ignored.
    static func parseClassifierVersion(_ data: Data) -> Int? {
        guard let string = String(data: data, encoding: .utf8) else {
            print("[ERROR] Invalid classifier version \(data)")
            return nil
        }
        // Plain-text bodies often end with "\n", which `Int(_:)` would reject.
        let trimmed = string.trimmingCharacters(in: .whitespacesAndNewlines)
        guard let version = Int(trimmed) else {
            print("[ERROR] Invalid classifier version \(string)")
            return nil
        }
        return version
    }

    /// Sends `request` and blocks until it completes.
    ///
    /// - Returns: The response body when the server answered with HTTP 200,
    ///   otherwise `nil` after printing an `[ERROR]` line. Transport failures
    ///   (no response at all) and non-HTTP responses are reported, not trapped.
    private func wait(for request: URLRequest) -> Data? {
        let group = DispatchGroup()
        var result: Data?

        group.enter()
        URLSession.shared.dataTask(with: request) { data, response, error in
            defer { group.leave() }

            // `response` is nil when the request never reached a server
            // (connection refused, DNS failure, timeout); a force-cast would crash.
            guard let http = response as? HTTPURLResponse else {
                if let error = error {
                    print("[ERROR] Request failed: \(error.localizedDescription)")
                } else {
                    print("[ERROR] Unexpected non-HTTP response")
                }
                return
            }

            let code = http.statusCode
            guard code == 200 else {
                print("[ERROR] Invalid response \(code)")
                return
            }

            guard let data = data else {
                print("[ERROR] No response data")
                return
            }

            result = data
        }.resume()

        // The completion handler of `URLSession.shared` runs on the session's own
        // queue, so blocking this thread does not prevent it from firing.
        group.wait()
        return result
    }

    /// POSTs `body` to `url` with the authentication `key` header.
    ///
    /// - Returns: `true` when the server answered HTTP 200 with a body.
    private func post(_ body: Data, to url: URL) -> Bool {
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.httpBody = body
        request.addValue(authentication, forHTTPHeaderField: "key")
        return wait(for: request) != nil
    }

    /// Fetches the currently published classifier version from `<server>/version`.
    ///
    /// - Returns: The version number, or `nil` if the request failed or the body
    ///   is not an integer.
    func getClassifierVersion() -> Int? {
        let classifierVersionUrl = server.appendingPathComponent("version")
        guard let data = wait(for: URLRequest(url: classifierVersionUrl)) else {
            return nil
        }
        return Server.parseClassifierVersion(data)
    }

    /// Uploads serialized model bytes to `<server>/classifier/<version>`.
    ///
    /// - Returns: `true` on an HTTP 200 response.
    func upload(classifier: Data, version: Int) -> Bool {
        let classifierUrl = server
            .appendingPathComponent("classifier")
            .appendingPathComponent(String(version))
        return post(classifier, to: classifierUrl)
    }

    /// Uploads the class labels to `<server>/classes` as newline-separated integers.
    ///
    /// - Returns: `true` on an HTTP 200 response.
    func upload(classes: [Int]) -> Bool {
        let classesUrl = server.appendingPathComponent("classes")
        let body = Data(classes.map { String($0) }.joined(separator: "\n").utf8)
        return post(body, to: classesUrl)
    }
}
|
|
|
|
|
|
// MARK: - Command-line arguments
//
// Expected invocation (five positional arguments after the executable):
//   <server-url> <auth-key> <image-directory> <model-output-path> <iterations>

let count = CommandLine.argc

guard count == 6 else {
    print("[ERROR] Invalid number of arguments")
    print("Usage: <executable> <server-url> <auth-key> <image-directory> <model-output-path> <iterations>")
    exit(1)
}

let serverPath = CommandLine.arguments[1]
let authenticationKey = CommandLine.arguments[2]
let imageDirectory = URL(fileURLWithPath: CommandLine.arguments[3])
let classifierUrl = URL(fileURLWithPath: CommandLine.arguments[4])
let iterationsString = CommandLine.arguments[5]

// `URL(string:)` also accepts relative or scheme-less strings; the `Server`
// client speaks HTTP, so insist on an http(s) URL up front.
guard let serverUrl = URL(string: serverPath),
      ["http", "https"].contains(serverUrl.scheme?.lowercased() ?? "") else {
    print("[ERROR] Invalid server path argument")
    exit(1)
}

// A zero or negative iteration budget is meaningless for training; reject it
// here rather than passing it on to CreateML.
guard let iterations = Int(iterationsString), iterations > 0 else {
    print("[ERROR] Invalid iterations argument")
    exit(1)
}

let server = Server(server: serverUrl, authentication: authenticationKey)
|
|
|
|
// MARK: - Class discovery

// Class labels sent to the server: every entry of `imageDirectory` whose name
// parses as an integer, in ascending order.
// NOTE(review): training below uses `.labeledDirectories(at:)` on the same
// directory, which may also pick up non-integer subdirectories that are
// silently excluded from this list — confirm that is intended.
let classes: [Int]

do {
    classes = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
        .compactMap(Int.init)
        // `contentsOfDirectory` returns entries in an undefined order; sorting
        // keeps the uploaded class list deterministic between runs.
        .sorted()
} catch {
    print("[ERROR] Failed to get model classes: \(error)")
    exit(1)
}
|
|
|
|
// MARK: - Version bump

// Version currently published on the server; abort when it cannot be read.
let oldVersion: Int
if let publishedVersion = server.getClassifierVersion() {
    oldVersion = publishedVersion
} else {
    print("[ERROR] Failed to get classifier version")
    exit(1)
}

// The newly trained classifier is published one version above the current one.
let newVersion = oldVersion + 1

// Echo the effective configuration before training starts.
print("[INFO] Image directory: \(imageDirectory.path)")
print("[INFO] Model path: \(classifierUrl.path)")
print("[INFO] Version: \(newVersion)")
print("[INFO] Classes: \(classes.count)")
print("[INFO] Iterations: \(iterations)")
|
|
|
|
// MARK: - Training

// Training configuration: no data augmentation; iteration cap from the command line.
var params: MLImageClassifier.ModelParameters = {
    var configuration = MLImageClassifier.ModelParameters(augmentation: [])
    configuration.maxIterations = iterations
    return configuration
}()

let model: MLImageClassifier
do {
    // Per the CreateML docs, each subdirectory of `imageDirectory` is one label.
    let trainingSource = MLImageClassifier.DataSource.labeledDirectories(at: imageDirectory)
    model = try MLImageClassifier(trainingData: trainingSource, parameters: params)
} catch let trainingError {
    print("[ERROR] Failed to create classifier: \(trainingError)")
    exit(1)
}

// Persist the trained model to the path given on the command line.
print("[INFO] Saving classifier...")
do {
    try model.write(to: classifierUrl)
} catch let saveError {
    print("[ERROR] Failed to save model to file: \(saveError)")
    exit(1)
}
|
|
|
|
// MARK: - Publishing

print("[INFO] Uploading classifier...")

// Read back the model file written by the training step.
let modelData: Data
do {
    modelData = try Data(contentsOf: classifierUrl)
} catch {
    // A bare top-level `try` would trap with an uncaught-error fatal error
    // instead of following the tool's "[ERROR] … / exit(1)" convention.
    print("[ERROR] Failed to read saved model: \(error)")
    exit(1)
}

guard server.upload(classifier: modelData, version: newVersion) else {
    print("[ERROR] Failed to upload classifier")
    exit(1)
}

// NOTE(review): if this second upload fails, the server already holds the new
// classifier without a matching class list — confirm the server tolerates that.
print("[INFO] Uploading trained classes...")
guard server.upload(classes: classes) else {
    print("[ERROR] Failed to upload classes")
    exit(1)
}

print("[INFO] Done")
exit(0)
|