Caps-Server/Training/train.swift

158 lines
4.6 KiB
Swift
Raw Normal View History

import Foundation
2021-12-20 17:52:32 +01:00
import Cocoa
import CreateML
/// Minimal synchronous client for the Caps classifier server.
///
/// Every request blocks the calling thread until `URLSession` delivers its
/// completion, which keeps the command-line flow strictly sequential.
final class Server {

    /// Base URL of the server; endpoints are appended as path components.
    let server: URL

    /// Secret sent in the `key` header of the upload (POST) requests.
    let authentication: String

    init(server: URL, authentication: String) {
        self.server = server
        self.authentication = authentication
    }

    /// Performs `request` synchronously and returns the body of a 200 response.
    ///
    /// - Parameter request: The fully configured request to send.
    /// - Returns: The response body, or `nil` (after printing the reason) when the
    ///   request fails, the response is not an HTTP response, the status code is
    ///   not 200, or no body was delivered.
    private func wait(for request: URLRequest) -> Data? {
        let group = DispatchGroup()
        group.enter()
        var result: Data? = nil
        URLSession.shared.dataTask(with: request) { data, response, error in
            defer { group.leave() }
            // Transport failures (unreachable host, timeout, bad URL) deliver a
            // nil response together with an error, so the status code must not
            // be read unconditionally.
            if let error = error {
                print("[ERROR] Request failed: \(error)")
                return
            }
            guard let httpResponse = response as? HTTPURLResponse else {
                print("[ERROR] Response is not an HTTP response")
                return
            }
            let code = httpResponse.statusCode
            guard code == 200 else {
                print("[ERROR] Invalid response \(code)")
                return
            }
            guard let data = data else {
                print("[ERROR] No response data")
                return
            }
            result = data
        }.resume()
        // The group is left exactly once in the completion's `defer`, which
        // also makes the write to `result` visible here.
        group.wait()
        return result
    }

    /// Fetches the current classifier version from the `version` endpoint.
    ///
    /// - Returns: The version number, or `nil` if the request fails or the body
    ///   is not a UTF-8 encoded integer.
    func getClassifierVersion() -> Int? {
        let classifierVersionUrl = server.appendingPathComponent("version")
        guard let data = wait(for: URLRequest(url: classifierVersionUrl)) else {
            return nil
        }
        guard let string = String(data: data, encoding: .utf8) else {
            print("[ERROR] Invalid classifier version \(data)")
            return nil
        }
        // Tolerate a trailing newline or surrounding whitespace in the body.
        guard let int = Int(string.trimmingCharacters(in: .whitespacesAndNewlines)) else {
            print("[ERROR] Invalid classifier version \(string)")
            return nil
        }
        return int
    }

    /// Uploads a compiled classifier model as `classifier/<version>`.
    ///
    /// - Parameters:
    ///   - classifier: The serialized model file contents.
    ///   - version: The version number the model is stored under.
    /// - Returns: `true` if the server answered with status 200.
    func upload(classifier: Data, version: Int) -> Bool {
        let classifierUrl = server
            .appendingPathComponent("classifier")
            .appendingPathComponent("\(version)")
        return post(classifier, to: classifierUrl)
    }

    /// Uploads the list of trained class ids, one decimal id per line.
    ///
    /// - Parameter classes: The class ids the classifier was trained on.
    /// - Returns: `true` if the server answered with status 200.
    func upload(classes: [Int]) -> Bool {
        let classesUrl = server.appendingPathComponent("classes")
        let body = Data(classes.map(String.init).joined(separator: "\n").utf8)
        return post(body, to: classesUrl)
    }

    /// Sends `body` to `url` as a POST request authenticated with the `key` header.
    /// - Returns: `true` if the server answered with status 200.
    private func post(_ body: Data, to url: URL) -> Bool {
        var request = URLRequest(url: url)
        request.httpMethod = "POST"
        request.httpBody = body
        request.addValue(authentication, forHTTPHeaderField: "key")
        return wait(for: request) != nil
    }
}
// MARK: - Command-line entry point
//
// Arguments: <server-url> <auth-key> <image-directory> <model-output-path> <iterations>
//
// Trains an image classifier from the labeled sub-directories of
// <image-directory>, writes it to <model-output-path>, then uploads it as the
// version after the server's current one, followed by the list of integer
// class ids. Every failure prints an [ERROR] line and exits with status 1.

let count = CommandLine.argc
guard count == 6 else {
    print("[ERROR] Invalid number of arguments")
    print("[INFO] Usage: train <server-url> <auth-key> <image-directory> <model-output-path> <iterations>")
    exit(1)
}
let serverPath = CommandLine.arguments[1]
let authenticationKey = CommandLine.arguments[2]
let imageDirectory = URL(fileURLWithPath: CommandLine.arguments[3])
let classifierUrl = URL(fileURLWithPath: CommandLine.arguments[4])
let iterationsString = CommandLine.arguments[5]
guard let serverUrl = URL(string: serverPath) else {
    print("[ERROR] Invalid server path argument")
    exit(1)
}
// A non-positive iteration count cannot produce a trained model.
guard let iterations = Int(iterationsString), iterations > 0 else {
    print("[ERROR] Invalid iterations argument")
    exit(1)
}
let server = Server(server: serverUrl, authentication: authenticationKey)

let classes: [Int]
do {
    // Directory entries whose names parse as integers are the class ids;
    // anything else (e.g. .DS_Store) is skipped. The listing order is
    // unspecified, so sort to upload a deterministic class list.
    classes = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
        .compactMap(Int.init)
        .sorted()
} catch {
    print("[ERROR] Failed to get model classes: \(error)")
    exit(1)
}

guard let oldVersion = server.getClassifierVersion() else {
    print("[ERROR] Failed to get classifier version")
    exit(1)
}
let newVersion = oldVersion + 1
print("[INFO] Image directory: \(imageDirectory.path)")
print("[INFO] Model path: \(classifierUrl.path)")
print("[INFO] Version: \(newVersion)")
print("[INFO] Classes: \(classes.count)")
print("[INFO] Iterations: \(iterations)")

// No data augmentation; only the iteration budget is configured.
var params = MLImageClassifier.ModelParameters(augmentation: [])
params.maxIterations = iterations

let model: MLImageClassifier
do {
    model = try MLImageClassifier(
        trainingData: .labeledDirectories(at: imageDirectory),
        parameters: params)
} catch {
    print("[ERROR] Failed to create classifier: \(error)")
    exit(1)
}

print("[INFO] Saving classifier...")
do {
    try model.write(to: classifierUrl)
} catch {
    print("[ERROR] Failed to save model to file: \(error)")
    exit(1)
}

print("[INFO] Uploading classifier...")
// Read the model back from disk. A failure is reported like every other step
// instead of escaping as an uncaught top-level error.
let modelData: Data
do {
    modelData = try Data(contentsOf: classifierUrl)
} catch {
    print("[ERROR] Failed to read saved model: \(error)")
    exit(1)
}
guard server.upload(classifier: modelData, version: newVersion) else {
    print("[ERROR] Failed to upload classifier")
    exit(1)
}

print("[INFO] Uploading trained classes...")
guard server.upload(classes: classes) else {
    print("[ERROR] Failed to upload classes")
    exit(1)
}

print("[INFO] Done")
exit(0)