Caps-Server/Training/train.swift
2023-01-17 15:09:03 +01:00

489 lines
16 KiB
Swift

import Foundation
import Cocoa
import CreateML
/// Settings for a training run, decoded from a JSON file.
struct Configuration: Codable {

    /// Local folder that holds the images, thumbnails and the trained classifier.
    let contentFolder: String

    /// Maximum number of iterations used when training the classifier.
    let trainingIterations: Int

    /// Base URL of the server, as a string.
    let serverPath: String

    /// Token sent along with every POST request to the server.
    let authenticationToken: String

    /// Reads and decodes the configuration stored at `url`.
    ///
    /// Logs the underlying error and fails (returns `nil`) when the file
    /// cannot be read or does not contain a valid configuration.
    init?(at url: URL) {
        let decoded: Configuration
        do {
            let rawContent = try Data(contentsOf: url)
            decoded = try JSONDecoder().decode(Configuration.self, from: rawContent)
        } catch {
            print("[ERROR] Failed to load configuration at \(url.absoluteURL.path): \(error)")
            return nil
        }
        self = decoded
    }
}
/// One entry of the server's cap database.
struct Cap: Codable {

    /// Maps the short JSON keys onto readable property names.
    enum CodingKeys: String, CodingKey {
        case id = "i"
        case count = "c"
    }

    /// Identifier of the cap (JSON key `i`).
    let id: Int

    /// Number of images available for the cap (JSON key `c`).
    let count: Int
}
/// Downloads the cap images from the server, trains an image classifier on
/// them and uploads the resulting model, class list and thumbnails.
final class ClassifierCreator {
/// Location of the configuration file (a relative path, "config.json").
static let configurationFileUrl = URL(fileURLWithPath: "config.json")
/// Base URL of the server, parsed from `configuration.serverPath`.
let server: URL
/// The configuration loaded from `configurationFileUrl`.
let configuration: Configuration
/// Local folder "images" inside the configured content folder.
let imageDirectory: URL
/// Local folder "thumbnails" inside the configured content folder.
let thumbnailDirectory: URL
/// Local file "classifier.mlmodel" inside the configured content folder.
let classifierUrl: URL
/// Formatter for the date that is sent along with the trained classes.
let df = DateFormatter()
// MARK: Step 1: Load configuration
/// Loads the configuration and derives all local and remote locations from it.
///
/// Fails (returns `nil`) after logging an error when the configuration file
/// cannot be loaded or the configured server path is not a valid URL.
init?() {
    guard let configuration = Configuration(at: ClassifierCreator.configurationFileUrl) else {
        return nil
    }
    self.configuration = configuration
    guard let serverUrl = URL(string: configuration.serverPath) else {
        print("[ERROR] Configuration: Invalid server path \(configuration.serverPath)")
        return nil
    }
    self.server = serverUrl
    let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
    self.imageDirectory = contentDirectory.appendingPathComponent("images")
    self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
    self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
    // A fixed-format formatter must use the POSIX locale; otherwise the user's
    // region or 12/24-hour preference can rewrite the pattern and produce
    // date strings that do not match the expected layout.
    df.locale = Locale(identifier: "en_US_POSIX")
    df.dateFormat = "yy-MM-dd-HH-mm-ss"
}
// MARK: Main function
/// Runs the complete pipeline: load images, train, upload the classifier,
/// upload the class list and create thumbnails.
///
/// The pipeline stops at the first failing step. When no image changed, no
/// classifier is trained and only the thumbnails are refreshed.
func run() async {
    // Taken before loading so that changes made during the run are not
    // reported to the server as already trained.
    let imagesSnapshotDate = Date()
    guard let (classes, changedImageCount, changedMainImages) = await loadImages() else {
        return
    }
    guard !classes.isEmpty else {
        // Previously this exited without any output, which made an empty or
        // unreachable cap database hard to tell apart from a successful run.
        print("[ERROR] No classes available, nothing to train")
        return
    }
    guard changedImageCount > 0 else {
        print("[INFO] No changed images, so no new classifier trained")
        await createThumbnails(changed: changedMainImages)
        print("[INFO] Done")
        return
    }
    guard let classifierVersion = await getClassifierVersion() else {
        return
    }
    let newVersion = classifierVersion + 1
    print("[INFO] Image directory: \(imageDirectory.absoluteURL.path)")
    print("[INFO] Model path: \(classifierUrl.path)")
    print("[INFO] Version: \(newVersion)")
    print("[INFO] Classes: \(classes.count)")
    print("[INFO] Iterations: \(configuration.trainingIterations)")
    guard trainModel() else {
        return
    }
    guard await upload(version: newVersion) else {
        return
    }
    guard await upload(classes: classes, lastUpdate: imagesSnapshotDate) else {
        return
    }
    await createThumbnails(changed: changedMainImages)
    print("[INFO] Done")
}
// MARK: Step 2: Load changed images
/// Brings the local image folder up to date with the server.
///
/// First downloads every image that the cap database lists but that does not
/// exist locally, then re-downloads every image named in the change list.
///
/// - Returns: The sorted cap ids (`classes`), the number of images that were
///   downloaded (`changedImageCount`) and the caps whose main image (index 0)
///   is in the change list (`changedMainImages`); `nil` when the image folder
///   cannot be created or any download fails.
func loadImages() async -> (classes: [Int], changedImageCount: Int, changedMainImages: [Int])? {
    guard createImageFolderIfMissing() else {
        return nil
    }
    let imageCounts = await getImageCounts()

    var missingImageList: [(cap: Int, image: Int)] = []
    for (cap, count) in imageCounts.sorted(by: { $0.key < $1.key }) {
        // `max` guards against a negative count, which would trap when
        // forming the range.
        for image in 0..<max(0, count) {
            let url = imageUrl(base: imageDirectory, cap: cap, image: image)
            if !FileManager.default.fileExists(atPath: url.path) {
                missingImageList.append((cap: cap, image: image))
            }
        }
    }
    if missingImageList.isEmpty {
        print("[INFO] No missing images to load")
    } else {
        print("[INFO] Loading \(missingImageList.count) missing images...")
    }
    guard await loadImages(missingImageList) else {
        return nil
    }

    let reportedChanges = await getChangedImageList()
    // The same image may be listed several times. Downloading it more than
    // once in parallel makes the downloads race on the same local file, so
    // each (cap, image) pair is kept only once, in order of first appearance.
    var seenChanges = Set<[Int]>()
    let changedImageList = reportedChanges.filter { change in
        // Filter non-existent images
        guard change.image >= 0, change.image < imageCounts[change.cap] ?? 0 else {
            return false
        }
        return seenChanges.insert([change.cap, change.image]).inserted
    }
    if changedImageList.isEmpty {
        print("[INFO] No changed images to load")
    } else {
        print("[INFO] Loading \(changedImageList.count) changed images...")
    }
    guard await loadImages(changedImageList) else {
        return nil
    }

    let changedMainImages = changedImageList.filter { $0.image == 0 }.map { $0.cap }
    let classes = imageCounts.keys.sorted()
    return (classes, missingImageList.count + changedImageList.count, changedMainImages)
}
/// Makes sure the local image directory exists.
///
/// - Returns: `true` when the directory exists or was created.
private func createImageFolderIfMissing() -> Bool {
    let folderIsAvailable = createFolderIfMissing(imageDirectory)
    return folderIsAvailable
}
/// Fetches "caps.json" from the server and maps each cap id to its image count.
///
/// - Returns: The counts keyed by cap id; an empty dictionary when the
///   request fails or the response cannot be decoded.
private func getImageCounts() async -> [Int : Int] {
    guard let payload: Data = await get(server.appendingPathComponent("caps.json")) else {
        return [:]
    }
    let caps: [Cap]
    do {
        caps = try JSONDecoder().decode([Cap].self, from: payload)
    } catch {
        print("[ERROR] Failed to decode cap database: \(error)")
        return [:]
    }
    var counts: [Int : Int] = [:]
    // A later entry with the same id overwrites an earlier one.
    for cap in caps {
        counts[cap.id] = cap.count
    }
    return counts
}
/// Builds the location of an image below `base`.
///
/// The relative path has the form `<cap>/<cap>-<image>.jpg`, with the cap id
/// zero-padded to four digits and the image index to two.
private func imageUrl(base: URL, cap: Int, image: Int) -> URL {
    let relativePath = String(format: "%04d/%04d-%02d.jpg", cap, cap, image)
    return base.appendingPathComponent(relativePath)
}
/// Downloads one image from the server into the local image directory,
/// replacing an existing local copy.
///
/// - Parameters:
///   - cap: Id of the cap the image belongs to.
///   - image: Index of the image within the cap.
/// - Returns: `true` when the image was downloaded and stored; `false` (after
///   logging the reason) otherwise.
private func loadImage(cap: Int, image: Int) async -> Bool {
    guard createFolderIfMissing(imageDirectory.appendingPathComponent(String(format: "%04d", cap))) else {
        return false
    }
    let url = imageUrl(base: server.appendingPathComponent("images"), cap: cap, image: image)
    let tempFile: URL, response: URLResponse
    do {
        (tempFile, response) = try await URLSession.shared.download(from: url)
    } catch {
        print("[ERROR] Failed to load image \(image) of cap \(cap): \(error)")
        return false
    }
    // Remove the temporary download on every failure path. After a
    // successful move the file is gone and the removal is a harmless no-op.
    defer { try? FileManager.default.removeItem(at: tempFile) }

    // A conditional cast avoids a crash when the response is not an HTTP
    // response (for example when the server path uses a non-HTTP scheme).
    guard let httpResponse = response as? HTTPURLResponse else {
        print("[ERROR] Failed to load image \(image) of cap \(cap): Not an HTTP response")
        return false
    }
    let responseCode = httpResponse.statusCode
    guard responseCode == 200 else {
        print("[ERROR] Failed to load image \(image) of cap \(cap): Response \(responseCode)")
        return false
    }
    do {
        let localUrl = imageUrl(base: imageDirectory, cap: cap, image: image)
        // `moveItem` fails when the destination exists, so clear it first.
        if FileManager.default.fileExists(atPath: localUrl.path) {
            try FileManager.default.removeItem(at: localUrl)
        }
        try FileManager.default.moveItem(at: tempFile, to: localUrl)
        return true
    } catch {
        print("[ERROR] Failed to save image \(image) of cap \(cap): \(error)")
        return false
    }
}
/// Downloads all images in `list` concurrently.
///
/// - Parameter list: The (cap, image) pairs to download.
/// - Returns: `true` when the list is empty or every image was loaded;
///   `false` (after logging how many succeeded) otherwise.
private func loadImages(_ list: [(cap: Int, image: Int)]) async -> Bool {
    if list.isEmpty {
        return true
    }
    let loadedImages = await withTaskGroup(of: Bool.self, returning: Int.self) { group in
        for entry in list {
            group.addTask {
                await self.loadImage(cap: entry.cap, image: entry.image)
            }
        }
        // Count the downloads that reported success.
        var successes = 0
        for await didLoad in group where didLoad {
            successes += 1
        }
        return successes
    }
    guard loadedImages == list.count else {
        print("[ERROR] Only \(loadedImages) of \(list.count) images loaded")
        return false
    }
    return true
}
/// Fetches "changes.txt" from the server and parses it into (cap, image) pairs.
///
/// Every non-empty line is split at ":" and must consist of exactly three
/// fields; the second field is read as the cap id and the third as the image
/// index. Lines with a different number of fields are skipped silently,
/// lines with a non-numeric cap or image are skipped with a warning.
///
/// - Returns: The parsed entries in file order; an empty list when the
///   request fails.
func getChangedImageList() async -> [(cap: Int, image: Int)] {
    guard let changeLog: String = await get(server.appendingPathComponent("changes.txt")) else {
        print("[ERROR] Failed to get list of changed images")
        return []
    }
    var changes: [(cap: Int, image: Int)] = []
    for line in changeLog.components(separatedBy: "\n") where line != "" {
        let fields = line.components(separatedBy: ":")
        if fields.count != 3 {
            continue
        }
        // The first field is currently not evaluated:
        // guard let date = df.date(from: parts[0]) else {
        //     print("[WARN] Invalid date \(parts[0]) in change list")
        //     return nil
        // }
        guard let cap = Int(fields[1]) else {
            print("[WARN] Invalid cap id \(fields[1]) in change list")
            continue
        }
        guard let image = Int(fields[2]) else {
            print("[WARN] Invalid image id \(fields[2]) in change list")
            continue
        }
        changes.append((cap: cap, image: image))
    }
    return changes
}
// MARK: Step 3: Compute version
/// Fetches the version number of the classifier currently on the server.
///
/// - Returns: The version; `nil` (after logging an error) when the request
///   fails or the response is not an integer.
func getClassifierVersion() async -> Int? {
    guard let string: String = await get(server.appendingPathComponent("version")) else {
        print("[ERROR] Failed to get classifier version")
        return nil
    }
    // `Int(_:)` rejects surrounding whitespace, so a response ending in a
    // newline would otherwise abort the whole run.
    guard let version = Int(string.trimmingCharacters(in: .whitespacesAndNewlines)) else {
        print("[ERROR] Invalid classifier version \(string)")
        return nil
    }
    return version
}
// MARK: Step 4: Train classifier
/// Trains an image classifier on the local image directory and writes it to
/// `classifierUrl`.
///
/// The sub-folders of the image directory are used as labeled training data,
/// no augmentation is applied, and the iteration limit comes from the
/// configuration.
///
/// - Returns: `true` when the model was trained and saved; `false` (after
///   logging the error) otherwise.
func trainModel() -> Bool {
    var parameters = MLImageClassifier.ModelParameters(augmentation: [])
    parameters.maxIterations = configuration.trainingIterations

    let classifier: MLImageClassifier
    do {
        classifier = try MLImageClassifier(
            trainingData: .labeledDirectories(at: imageDirectory),
            parameters: parameters)
    } catch {
        print("[ERROR] Failed to create classifier: \(error)")
        return false
    }

    print("[INFO] Saving classifier...")
    do {
        try classifier.write(to: classifierUrl)
    } catch {
        print("[ERROR] Failed to save model to file: \(error)")
        return false
    }
    return true
}
// MARK: Step 5: Upload classifier
/// Uploads the trained classifier file to the server under the given version.
///
/// - Parameter version: Version number used in the upload path.
/// - Returns: `true` when the file could be read and the upload succeeded.
func upload(version: Int) async -> Bool {
    print("[INFO] Uploading classifier...")
    do {
        let classifierData = try Data(contentsOf: classifierUrl)
        let destination = server.appendingPathComponent("classifier/\(version)")
        return await post(url: destination, body: classifierData)
    } catch {
        // Only reading the file can throw here; `post` reports its own errors.
        print("[ERROR] Failed to read classifier data: \(error)")
        return false
    }
}
// MARK: Step 6: Update classes
/// Uploads the list of trained classes to the server.
///
/// The class ids are sent one per line; the formatted date is part of the
/// upload path.
///
/// - Parameters:
///   - classes: Ids of the caps the classifier was trained on.
///   - lastUpdate: Date that is formatted into the upload path.
/// - Returns: `true` when the upload succeeded.
func upload(classes: [Int], lastUpdate: Date) async -> Bool {
    print("[INFO] Uploading trained classes...")
    let destination = server.appendingPathComponent("classes/\(df.string(from: lastUpdate))")
    let classList = classes.map(String.init).joined(separator: "\n")
    return await post(url: destination, body: Data(classList.utf8))
}
// MARK: Step 7: Create thumbnails
/// Creates and uploads thumbnails for all caps that need one.
///
/// A thumbnail is created for every cap the server reports as missing one,
/// plus every cap in `changed`. Nothing happens when ImageMagick is not
/// available or the thumbnail folder cannot be created.
///
/// - Parameter changed: Ids of the caps whose main image changed.
func createThumbnails(changed: [Int]) async {
    guard checkMagickAvailability() else {
        return
    }
    guard createFolderIfMissing(thumbnailDirectory) else {
        print("[ERROR] Failed to create folder for thumbnails")
        return
    }
    let missingOnServer = await getMissingThumbnailIds()
    // A set so that a cap that is both missing and changed is handled once.
    var pendingCaps = Set(missingOnServer)
    pendingCaps.formUnion(changed)
    print("[INFO] Creating \(pendingCaps.count) thumbnails...")
    for cap in pendingCaps {
        await createThumbnail(for: cap)
    }
}
/// Checks that the `magick` command can be run and logs its version.
///
/// The version is the word following "ImageMagick " in the output of
/// `magick --version`.
///
/// - Returns: `true` when the command exits with code 0 and a version could
///   be extracted; `false` (after logging an error) otherwise.
func checkMagickAvailability() -> Bool {
    let result: (code: Int32, output: String)
    do {
        result = try safeShell("magick --version")
    } catch {
        print("[ERROR] Failed to get version of magick: (\(error))")
        return false
    }
    let versionToken = result.output
        .components(separatedBy: "ImageMagick ")
        .dropFirst()
        .first?
        .components(separatedBy: " ")
        .first
    guard result.code == 0, let version = versionToken else {
        print("[ERROR] Magick not found, install using 'brew install imagemagick'")
        return false
    }
    print("[INFO] Using magick \(version)")
    return true
}
/// Fetches the ids of all caps for which the server has no thumbnail.
///
/// The response is read as a comma-separated list of integers; entries that
/// are not integers are ignored.
///
/// - Returns: The cap ids; an empty list (after logging an error) when the
///   request fails.
private func getMissingThumbnailIds() async -> [Int] {
    guard let string: String = await get(server.appendingPathComponent("thumbnails/missing")) else {
        print("[ERROR] Failed to get missing thumbnails")
        return []
    }
    // Trim each entry: `Int(_:)` rejects surrounding whitespace, so an id
    // followed by a newline or preceded by a space would be dropped silently.
    return string
        .components(separatedBy: ",")
        .compactMap { Int($0.trimmingCharacters(in: .whitespacesAndNewlines)) }
}
/// Creates a 100x100 thumbnail from the main image of a cap using
/// ImageMagick and uploads it to the server.
///
/// Every failure is logged and ends the function; nothing is returned.
///
/// - Parameter cap: Id of the cap to create the thumbnail for.
private func createThumbnail(for cap: Int) async {
    let inputUrl = imageUrl(base: imageDirectory, cap: cap, image: 0)
    guard FileManager.default.fileExists(atPath: inputUrl.path) else {
        print("[ERROR] Local main image not found for cap \(cap): \(inputUrl.path)")
        return
    }
    let outputUrl = thumbnailDirectory.appendingPathComponent(String(format: "%04d.jpg", cap))

    // The command is interpreted by a shell, so the paths must be quoted;
    // otherwise a content folder containing spaces or shell metacharacters
    // breaks the command. Single quotes inside a path are escaped as '\''.
    let quoted: (String) -> String = { path in
        "'" + path.replacingOccurrences(of: "'", with: "'\\''") + "'"
    }
    do {
        let command = "magick convert \(quoted(inputUrl.path)) -quality 70% -resize 100x100 \(quoted(outputUrl.path))"
        let (code, shellOutput) = try safeShell(command)
        if code != 0 {
            print("[ERROR] Failed to create thumbnail for cap \(cap): \(shellOutput)")
            return
        }
    } catch {
        // Reached when the shell process itself could not be run.
        print("[ERROR] Failed to run magick for cap \(cap): \(error)")
        return
    }

    let data: Data
    do {
        data = try Data(contentsOf: outputUrl)
    } catch {
        print("[ERROR] Failed to read created thumbnail for cap \(cap): \(error)")
        return
    }
    guard await post(url: server.appendingPathComponent("thumbnails/\(cap)"), body: data) else {
        print("[ERROR] Failed to upload thumbnail for cap \(cap)")
        return
    }
}
// MARK: Helper
/// Runs `command` in a zsh login shell and waits for it to finish.
///
/// Standard output and standard error are captured together.
///
/// - Parameter command: The command line passed to `/bin/zsh -cl`.
/// - Returns: The exit code of the shell and everything it printed.
/// - Throws: The error from `Process.run()` when the shell cannot be started.
@discardableResult
private func safeShell(_ command: String) throws -> (code: Int32, output: String) {
    let task = Process()
    let pipe = Pipe()
    task.standardOutput = pipe
    task.standardError = pipe
    task.arguments = ["-cl", command]
    task.executableURL = URL(fileURLWithPath: "/bin/zsh")
    task.standardInput = nil
    try task.run()
    // Drain the pipe BEFORE waiting for the process: once the pipe buffer is
    // full the child blocks on writing, and waiting first would then never
    // return.
    let data = pipe.fileHandleForReading.readDataToEndOfFile()
    task.waitUntilExit()
    // Lossy decoding: invalid UTF-8 in the output becomes U+FFFD instead of
    // crashing on a force-unwrapped `nil`.
    let output = String(decoding: data, as: UTF8.self)
    return (task.terminationStatus, output)
}
/// Makes sure `folder` exists, creating it (and missing parents) if needed.
///
/// - Parameter folder: Location of the folder.
/// - Returns: `true` when something already exists at that path or the
///   folder was created; `false` (after logging the error) otherwise.
private func createFolderIfMissing(_ folder: URL) -> Bool {
    let fileManager = FileManager.default
    if fileManager.fileExists(atPath: folder.path) {
        return true
    }
    do {
        try fileManager.createDirectory(at: folder, withIntermediateDirectories: true)
    } catch {
        print("[ERROR] Failed to create directory \(folder.path): \(error)")
        return false
    }
    return true
}
// MARK: Requests
/// Sends `body` to `url` with a POST request.
///
/// The authentication token from the configuration is sent in the "key"
/// header.
///
/// - Returns: `true` when the request was answered successfully.
private func post(url: URL, body: Data) async -> Bool {
    var postRequest = URLRequest(url: url)
    postRequest.httpMethod = "POST"
    postRequest.httpBody = body
    postRequest.addValue(configuration.authenticationToken, forHTTPHeaderField: "key")
    let responseBody = await perform(postRequest)
    return responseBody != nil
}
/// Executes `request` and returns the response body.
///
/// - Returns: The body when the server answers with status 200; `nil` (after
///   logging an error) when the request fails, the response is not an HTTP
///   response, or the status is anything else.
private func perform(_ request: URLRequest) async -> Data? {
    let data: Data
    let response: URLResponse
    do {
        (data, response) = try await URLSession.shared.data(for: request)
    } catch {
        print("[ERROR] Request failed: \(error)")
        return nil
    }
    // A conditional cast avoids a crash when the response is not an HTTP
    // response (for example when the server path uses a non-HTTP scheme).
    guard let httpResponse = response as? HTTPURLResponse else {
        print("[ERROR] Invalid response: Not an HTTP response")
        return nil
    }
    let code = httpResponse.statusCode
    guard code == 200 else {
        print("[ERROR] Invalid response \(code)")
        return nil
    }
    return data
}
/// Requests `url` and returns the raw response body.
///
/// - Returns: The body, or `nil` when the request fails.
private func get(_ url: URL) async -> Data? {
    let request = URLRequest(url: url)
    return await perform(request)
}
/// Requests `url` and returns the response body as UTF-8 text.
///
/// - Returns: The text, or `nil` when the request fails or the body is not
///   valid UTF-8 (the latter is logged).
private func get(_ url: URL) async -> String? {
    let body: Data? = await get(url)
    guard let data = body else {
        return nil
    }
    if let text = String(data: data, encoding: .utf8) {
        return text
    }
    print("[ERROR] Invalid string response \(data)")
    return nil
}
}
// Entry point: run the pipeline when the configuration could be loaded.
if let creator = ClassifierCreator() {
    await creator.run()
}