Caps-Train/Sources/ClassifierCreator.swift
2023-10-24 10:38:24 +02:00

526 lines
19 KiB
Swift

import Foundation
import CreateML
import Combine
/// Loads the cap images from the server, trains an image classifier on them
/// and uploads the resulting model, the trained classes and the thumbnails.
final class ClassifierCreator {
/// Revision passed to the `.scenePrint` feature extractor when training.
private let scenePrintRevision: Int? = 2
/// Passed as `reportInterval` of the training session parameters.
private let sessionProgressIterations = 10
/// Passed as `checkpointInterval` of the training session parameters.
private let sessionCheckpointIterations = 100
/// Passed as `iterations` of the training session parameters.
private let sessionIterations = 1000
/// Base URL of the server, obtained from `configuration.serverUrl()`.
let server: URL
/// The configuration this instance was created with.
let configuration: Configuration
/// Local folder `<contentFolder>/images` holding one sub-folder of images per cap.
let imageDirectory: URL
/// Local folder `<contentFolder>/thumbnails` where generated thumbnails are written.
let thumbnailDirectory: URL
/// Local folder `<contentFolder>/session` handed to the training session.
let sessionDirectory: URL
/// Local file `<contentFolder>/classifier.mlmodel` the trained model is written to.
let classifierUrl: URL
/// Formatter for the `yy-MM-dd-HH-mm-ss` timestamps used in the change list
/// and in the `classes/<date>` upload path.
let df = DateFormatter()
/// Writes an informational message to standard output, prefixed with `[INFO] `.
///
/// - Parameter info: The message text to print.
private func print(info: String) {
    let line = "[INFO] " + info
    Swift.print(line)
}
// MARK: Step 1: Load configuration
/// Creates a classifier creator from the given configuration.
///
/// All local paths (`images`, `session`, `thumbnails`, `classifier.mlmodel`)
/// are derived from `configuration.contentFolder`.
///
/// - Parameter configuration: The configuration providing the server URL and content folder.
/// - Throws: Any error thrown by `configuration.serverUrl()`.
init(configuration: Configuration) throws {
    self.configuration = configuration
    self.server = try configuration.serverUrl()
    let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
    self.imageDirectory = contentDirectory.appendingPathComponent("images")
    self.sessionDirectory = contentDirectory.appendingPathComponent("session")
    self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
    self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
    // A fixed-format timestamp must not depend on the user's locale or calendar,
    // otherwise e.g. a non-Gregorian system calendar changes the formatted year.
    df.locale = Locale(identifier: "en_US_POSIX")
    df.dateFormat = "yy-MM-dd-HH-mm-ss"
}
// MARK: Main function
/// Runs the complete pipeline: load images, train, upload the model and classes, create thumbnails.
///
/// Returns early when the server reports no classes. When no images changed,
/// training and uploading are skipped and only the thumbnails are created.
///
/// - Throws: Any error raised by one of the individual steps.
func run() async throws {
    // Taken before the images are loaded; later sent to the server as `lastUpdate`.
    let snapshotDate = Date()
    let loaded = try await loadImages()

    if loaded.classes.isEmpty {
        print(info: "No image classes found, exiting...")
        return
    }

    if loaded.changedImageCount <= 0 {
        print(info: "No changed images, so no new classifier trained")
        try await createThumbnails(changed: loaded.changedMainImages)
        print(info: "Done")
        return
    }

    let newVersion = try await getClassifierVersion() + 1
    let summary = [
        "Image directory: \(imageDirectory.absoluteURL.path)",
        "Model path: \(classifierUrl.path)",
        "Version: \(newVersion)",
        "Classes: \(loaded.classes.count)",
        "Iterations: \(configuration.trainingIterations)",
    ]
    for line in summary {
        print(info: line)
    }

    try await trainAndSaveModel()
    try await uploadModel(version: newVersion)
    try await upload(classes: loaded.classes, lastUpdate: snapshotDate)
    try await createThumbnails(changed: loaded.changedMainImages)
    print(info: "Done")
}
// MARK: Step 2: Load changed images
/// Synchronizes the local image folder with the server.
///
/// Downloads every image that is missing locally, re-downloads every image listed
/// in the server's change list, and removes local folders of caps that no longer exist.
///
/// - Returns: The sorted cap ids (`classes`), the number of missing plus changed images
///   (`changedImageCount`), and the ids of existing caps whose main image (index 0)
///   is in the change list (`changedMainImages`).
/// - Throws: `TrainingError.mainImageFolderNotCreated` if the image folder can't be created,
///   or any error from fetching the cap list, the change list, or the images.
func loadImages() async throws -> (classes: [Int], changedImageCount: Int, changedMainImages: [Int]) {
    do {
        try createFolderIfMissing(imageDirectory)
    } catch {
        throw TrainingError.mainImageFolderNotCreated(error)
    }
    let imageCounts = try await getImageCounts()

    // Every image the server reports that has no local file yet.
    let missingImageList: [CapImage] = imageCounts
        .sorted { $0.key < $1.key }
        .reduce(into: []) { list, pair in
            // `max(0, …)` guards against a negative count, which would trap when forming the range.
            let missingImagesForCap: [CapImage] = (0..<max(0, pair.value)).compactMap { index in
                let image = CapImage(cap: pair.key, image: index)
                let url = imageUrl(base: imageDirectory, image: image)
                guard !FileManager.default.fileExists(atPath: url.path) else {
                    return nil
                }
                return image
            }
            list.append(contentsOf: missingImagesForCap)
        }
    if missingImageList.isEmpty {
        print(info: "No missing images to load")
    } else {
        print(info: "Loading \(missingImageList.count) missing images...")
        try await loadImages(missingImageList)
    }

    let changedImageList = try await getChangedImageList()
    // Changes that refer to images the server still has.
    let existingChanges = changedImageList
        .filter { $0.image < imageCounts[$0.cap] ?? 0 }
    // Images downloaded above as "missing" are already up to date.
    let filteredChangeList = existingChanges
        .filter { !missingImageList.contains($0) }
    let imagesAlreadyLoad = changedImageList.count - filteredChangeList.count
    let suffix = imagesAlreadyLoad > 0 ? " (\(imagesAlreadyLoad) already loaded)" : ""
    if filteredChangeList.isEmpty {
        print(info: "No changed images to load" + suffix)
    } else {
        print(info: "Loading \(filteredChangeList.count) changed images" + suffix)
        try await loadImages(filteredChangeList)
    }

    // Only main images that still exist can be turned into thumbnails; a change entry
    // for a removed cap would otherwise fail thumbnail creation with `missingMainImage`.
    let changedMainImages = existingChanges.filter { $0.image == 0 }.map { $0.cap }
    let classes = imageCounts.keys.sorted()

    // Delete any image folders not present as caps
    try deleteUnnecessaryImageFolders(caps: classes)
    return (classes, missingImageList.count + changedImageList.count, changedMainImages)
}
/// Fetches `caps.json` from the server and maps each cap id to its image count.
///
/// - Returns: A dictionary from cap id to the `count` reported for that cap.
/// - Throws: `TrainingError.failedToGetCapDatabase` if the request fails,
///   `TrainingError.failedToDecodeCapDatabase` if the response can't be decoded.
private func getImageCounts() async throws -> [Int : Int] {
    let payload: Data
    do {
        payload = try await get(server.appendingPathComponent("caps.json"))
    } catch {
        throw TrainingError.failedToGetCapDatabase(error)
    }

    let caps: [Cap]
    do {
        caps = try JSONDecoder().decode([Cap].self, from: payload)
    } catch {
        throw TrainingError.failedToDecodeCapDatabase(error)
    }

    var counts = [Int : Int]()
    for cap in caps {
        counts[cap.id] = cap.count
    }
    return counts
}
/// Removes every entry of the image directory that is not the folder of a known cap.
///
/// Cap folders are named with the zero-padded cap id (`%04d`).
///
/// - Parameter caps: The ids of all caps that should keep their folder.
/// - Throws: `TrainingError.failedToGetListOfImageFolders` if the directory can't be listed,
///   `TrainingError.failedToRemoveImageFolder` if an entry can't be removed.
private func deleteUnnecessaryImageFolders(caps: [Int]) throws {
    // A set makes the per-entry membership check constant time.
    let validNames = Set(caps.map { String(format: "%04d", $0) })
    let folders: [String]
    do {
        folders = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
    } catch {
        throw TrainingError.failedToGetListOfImageFolders(error)
    }
    for folder in folders where !validNames.contains(folder) {
        // Not a valid cap folder
        let url = imageDirectory.appendingPathComponent(folder)
        do {
            try FileManager.default.removeItem(at: url)
            print(info: "Removed unused image folder '\(folder)'")
        } catch {
            throw TrainingError.failedToRemoveImageFolder(folder, error)
        }
    }
}
/// Builds the location of an image below `base` as `<cap>/<cap>-<image>.jpg`,
/// with the cap padded to four digits and the image index to two.
///
/// - Parameters:
///   - base: The folder or server URL the relative path is appended to.
///   - image: The cap and image index to address.
/// - Returns: `base` with the relative image path appended.
private func imageUrl(base: URL, image: CapImage) -> URL {
    let relativePath = String(format: "%04d/%04d-%02d.jpg", image.cap, image.cap, image.image)
    return base.appendingPathComponent(relativePath)
}
/// Downloads a single image from the server and stores it in the local image folder,
/// replacing any existing local file.
///
/// - Parameter image: The cap and image index to download.
/// - Throws: `TrainingError.failedToCreateImageFolder` if the cap folder can't be created,
///   `TrainingError.failedToLoadImage` if the download fails or yields a non-HTTP response,
///   `TrainingError.invalidImageRequestResponse` for a status code other than 200,
///   `TrainingError.failedToSaveImage` if the file can't be moved into place.
private func load(image: CapImage) async throws {
    do {
        try createFolderIfMissing(imageDirectory.appendingPathComponent(String(format: "%04d", image.cap)))
    } catch {
        throw TrainingError.failedToCreateImageFolder(image.cap, error)
    }
    let url = imageUrl(base: server.appendingPathComponent("images"), image: image)
    let tempFile: URL, response: URLResponse
    do {
        (tempFile, response) = try await URLSession.shared.download(from: url)
    } catch {
        throw TrainingError.failedToLoadImage(image, error)
    }
    // A conditional cast instead of `as!` so that an unexpected response type
    // produces an error rather than a crash.
    guard let httpResponse = response as? HTTPURLResponse else {
        try? FileManager.default.removeItem(at: tempFile)
        throw TrainingError.failedToLoadImage(image, URLError(.badServerResponse))
    }
    guard httpResponse.statusCode == 200 else {
        // Don't leave the downloaded error body behind in the temporary directory.
        try? FileManager.default.removeItem(at: tempFile)
        throw TrainingError.invalidImageRequestResponse(image, httpResponse.statusCode)
    }
    do {
        let localUrl = imageUrl(base: imageDirectory, image: image)
        if FileManager.default.fileExists(atPath: localUrl.path) {
            try FileManager.default.removeItem(at: localUrl)
        }
        try FileManager.default.moveItem(at: tempFile, to: localUrl)
    } catch {
        throw TrainingError.failedToSaveImage(image, error)
    }
}
/// Downloads all given images concurrently.
///
/// Each individual failure is printed; the function throws afterwards
/// if not every image was loaded.
///
/// - Parameter list: The images to download.
/// - Throws: `TrainingError.failedToLoadImages` if at least one download failed.
private func loadImages(_ list: [CapImage]) async throws {
    let failures = await withTaskGroup(of: Error?.self, returning: [Error].self) { group in
        for image in list {
            group.addTask {
                do {
                    try await self.load(image: image)
                    return nil
                } catch {
                    return error
                }
            }
        }
        // Each task yields `nil` on success or its error on failure.
        var collected = [Error]()
        for await outcome in group {
            if let outcome {
                collected.append(outcome)
            }
        }
        return collected
    }

    failures.forEach { Swift.print($0.localizedDescription) }

    let loadedImageCount = list.count - failures.count
    guard loadedImageCount == list.count else {
        throw TrainingError.failedToLoadImages(expected: list.count, loaded: loadedImageCount)
    }
}
/// Fetches `changes.txt` from the server and parses it into the list of changed images.
///
/// Each non-empty line must have the form `<date>:<cap>:<image>`, where the date
/// matches the format of `df`. Both LF and CRLF line endings are accepted.
///
/// - Returns: One `CapImage` per entry, in file order.
/// - Throws: `TrainingError.failedToGetChangedImagesList` if the request fails,
///   `TrainingError.invalidEntryInChangeList` for a malformed line.
func getChangedImageList() async throws -> [CapImage] {
    let string: String
    do {
        string = try await get(server.appendingPathComponent("changes.txt"))
    } catch {
        throw TrainingError.failedToGetChangedImagesList(error)
    }
    return try string
        // Splitting on the newline character set also consumes the `\r` of CRLF files,
        // which would otherwise stick to the last field and make it unparsable.
        .components(separatedBy: .newlines)
        .map { $0.trimmingCharacters(in: .whitespaces) }
        .filter { !$0.isEmpty }
        .map { entry -> CapImage in
            let parts = entry.components(separatedBy: ":")
            guard parts.count == 3,
                  df.date(from: parts[0]) != nil,
                  let cap = Int(parts[1]),
                  let image = Int(parts[2]) else {
                throw TrainingError.invalidEntryInChangeList(entry)
            }
            return CapImage(cap: cap, image: image)
        }
}
// MARK: Step 3: Compute version
/// Fetches the current classifier version from the server's `version` endpoint.
///
/// - Returns: The version parsed from the response body.
/// - Throws: `TrainingError.failedToGetClassifierVersion` if the request fails,
///   `TrainingError.invalidClassifierVersion` if the body is not an integer.
func getClassifierVersion() async throws -> Int {
    let string: String
    do {
        string = try await get(server.appendingPathComponent("version"))
    } catch {
        throw TrainingError.failedToGetClassifierVersion(error)
    }
    // Tolerate surrounding whitespace such as a trailing newline in the response.
    let trimmed = string.trimmingCharacters(in: .whitespacesAndNewlines)
    guard let version = Int(trimmed) else {
        throw TrainingError.invalidClassifierVersion(string)
    }
    return version
}
// MARK: Step 4: Train classifier
/// Trains the classifier with the asynchronous training job and writes it to disk.
///
/// - Throws: Any error from `trainModelAsync()` or `save(model:)`.
func trainAndSaveModel() async throws {
    // Synchronous alternative: `try trainModelSync()`
    try await save(model: trainModelAsync())
}
/// Trains the image classifier using a CreateML training job and waits for its result.
///
/// Training uses transfer learning with the scene print feature extractor and a
/// logistic regressor, on the labeled sub-folders of `imageDirectory`.
/// Progress is printed to standard output while the job runs.
///
/// - Returns: The trained classifier delivered by the job.
/// - Throws: `TrainingError.failedToCreateClassifier` if the job can't be started or fails.
func trainModelAsync() async throws -> MLImageClassifier {
let params = MLImageClassifier.ModelParameters(
maxIterations: configuration.trainingIterations,
augmentation: [],
algorithm: .transferLearning(
featureExtractor: .scenePrint(revision: scenePrintRevision),
classifier: .logisticRegressor))
let sessionParameters = MLTrainingSessionParameters(
sessionDirectory: sessionDirectory,
reportInterval: sessionProgressIterations,
checkpointInterval: sessionCheckpointIterations,
iterations: sessionIterations)
// Holds the Combine subscriptions created below.
// NOTE(review): the sinks only stay active while this local array is alive - confirm
// that it reliably outlives the suspension at the continuation below.
var subscriptions = [AnyCancellable]()
let job: MLJob<MLImageClassifier>
do {
job = try MLImageClassifier.train(
trainingData: .labeledDirectories(at: imageDirectory),
parameters: params,
sessionParameters: sessionParameters)
} catch {
throw TrainingError.failedToCreateClassifier(error)
}
// Print the completed fraction on a single line that is overwritten on every update.
job.progress
.publisher(for: \.fractionCompleted)
.sink { completed in
Swift.print(String(format: " %.1f %% completed", completed * 100), terminator: "\r")
// Flush explicitly, since the line is not terminated by a newline.
fflush(stdout)
//guard let progress = MLProgress(progress: job.progress) else {
// return
//}
//if let styleLoss = progress.metrics[.styleLoss] { _ = styleLoss }
//if let contentLoss = progress.metrics[.contentLoss] { _ = contentLoss }
}
.store(in: &subscriptions)
// Bridge the job's result publisher into async/await.
return try await withCheckedThrowingContinuation { continuation in
// Register a sink to receive the resulting model.
job.result.sink { result in
switch result {
case .finished:
break // Continuation already called with model
case .failure(let error):
continuation.resume(throwing: TrainingError.failedToCreateClassifier(error))
}
} receiveValue: { [weak self] model in
// Hand the trained model to the awaiting caller.
self?.print(info: "Created model")
continuation.resume(returning: model)
}
.store(in: &subscriptions)
}
}
/// Trains the image classifier synchronously on the labeled sub-folders of `imageDirectory`.
///
/// - Returns: The trained classifier.
/// - Throws: `TrainingError.failedToCreateClassifier` if training fails.
func trainModelSync() throws -> MLImageClassifier {
    let parameters = MLImageClassifier.ModelParameters(
        maxIterations: configuration.trainingIterations,
        augmentation: [])
    let trainingData = MLImageClassifier.DataSource.labeledDirectories(at: imageDirectory)

    let classifier: MLImageClassifier
    do {
        classifier = try MLImageClassifier(trainingData: trainingData, parameters: parameters)
    } catch {
        throw TrainingError.failedToCreateClassifier(error)
    }
    return classifier
}
/// Writes the trained classifier to `classifierUrl`.
///
/// - Parameter model: The classifier to persist.
/// - Throws: `TrainingError.failedToWriteClassifier` if writing fails.
private func save(model: MLImageClassifier) throws {
    print(info: "Saving classifier...")
    let destination = classifierUrl
    do {
        try model.write(to: destination)
    } catch {
        throw TrainingError.failedToWriteClassifier(error)
    }
}
// MARK: Step 5: Upload classifier
/// Uploads the classifier file at `classifierUrl` to `classifier/<version>` on the server.
///
/// - Parameter version: The version number the model is uploaded as.
/// - Throws: `TrainingError.failedToReadClassifierData` if the model file can't be read,
///   `TrainingError.failedToUploadClassifier` if the upload fails.
func uploadModel(version: Int) async throws {
    print(info: "Uploading classifier...")
    let endpoint = server.appendingPathComponent("classifier/\(version)")

    let payload: Data
    do {
        payload = try Data(contentsOf: classifierUrl)
    } catch {
        throw TrainingError.failedToReadClassifierData(error)
    }

    do {
        try await post(url: endpoint, body: payload)
    } catch {
        throw TrainingError.failedToUploadClassifier(error)
    }
}
// MARK: Step 6: Update classes
/// Uploads the list of trained classes to `classes/<date>` on the server.
///
/// The body is the comma-separated list of class ids; the date in the path is
/// `lastUpdate` formatted with `df`.
///
/// - Parameters:
///   - classes: The cap ids the classifier was trained on.
///   - lastUpdate: The timestamp sent along with the classes.
/// - Throws: `TrainingError.failedToUploadClassifierClasses` if the upload fails.
func upload(classes: [Int], lastUpdate: Date) async throws {
    print(info: "Uploading trained classes...")
    let dateString = df.string(from: lastUpdate)
    let url = server.appendingPathComponent("classes/\(dateString)")
    // `Data(_.utf8)` can't fail, so no force unwrap is needed.
    let body = Data(classes.map(String.init).joined(separator: ",").utf8)
    do {
        try await post(url: url, body: body)
    } catch {
        throw TrainingError.failedToUploadClassifierClasses(error)
    }
}
// MARK: Step 7: Create thumbnails
/// Creates and uploads thumbnails for all caps the server reports as missing one,
/// plus the caps given in `changed`.
///
/// - Parameter changed: Ids of caps whose thumbnail should be recreated.
/// - Throws: An error if ImageMagick is unavailable,
///   `TrainingError.failedToCreateThumbnailFolder` if the thumbnail folder can't be created,
///   or any error from fetching the missing ids or creating a thumbnail.
func createThumbnails(changed: [Int]) async throws {
    try ensureMagickAvailability()
    do {
        try createFolderIfMissing(thumbnailDirectory)
    } catch {
        throw TrainingError.failedToCreateThumbnailFolder(error)
    }

    // Collect into a set so that each cap is processed once.
    var pending = try await Set(getMissingThumbnailIds())
    pending.formUnion(changed)

    print(info: "Creating \(pending.count) thumbnails...")
    for cap in pending {
        try await createThumbnail(for: cap)
    }
}
/// Verifies that the `magick` command is available by running `magick --version`.
///
/// - Throws: `TrainingError.magickDependencyCheckFailed` if the command can't be run,
///   `TrainingError.magickDependencyNotFound` if it exits with a non-zero status
///   or its output contains no ImageMagick version.
func ensureMagickAvailability() throws {
    let result: (code: Int32, output: String)
    do {
        result = try safeShell("magick --version")
    } catch {
        throw TrainingError.magickDependencyCheckFailed(error)
    }
    // Checked outside of the `do` block, so that `magickDependencyNotFound` reaches
    // the caller directly instead of being wrapped in `magickDependencyCheckFailed`.
    guard result.code == 0,
          let version = result.output.components(separatedBy: "ImageMagick ").dropFirst().first?
            .components(separatedBy: " ").first else {
        throw TrainingError.magickDependencyNotFound
    }
    print(info: "Using magick \(version)")
}
/// Fetches the comma-separated list of cap ids without a thumbnail from `thumbnails/missing`.
///
/// Entries that are not integers (such as the empty string of an empty list) are skipped.
///
/// - Returns: The ids of all caps that need a thumbnail.
/// - Throws: `TrainingError.failedToGetMissingThumbnailIds` if the request fails.
private func getMissingThumbnailIds() async throws -> [Int] {
    let string: String
    do {
        string = try await get(server.appendingPathComponent("thumbnails/missing"))
    } catch {
        throw TrainingError.failedToGetMissingThumbnailIds(error)
    }
    return string
        .components(separatedBy: ",")
        // Trim first, so that an id followed by a newline or space is not silently dropped.
        .compactMap { Int($0.trimmingCharacters(in: .whitespacesAndNewlines)) }
}
/// Creates a 100x100 thumbnail of a cap's main image with ImageMagick and uploads it
/// to `thumbnails/<cap>` on the server.
///
/// - Parameter cap: The id of the cap to create the thumbnail for.
/// - Throws: `TrainingError.missingMainImage` if the main image doesn't exist locally,
///   `TrainingError.failedToCreateThumbnail` if the conversion can't be run or fails,
///   `TrainingError.failedToReadCreatedThumbnail` if the result can't be read,
///   `TrainingError.failedToUploadCreatedThumbnail` if the upload fails.
private func createThumbnail(for cap: Int) async throws {
    /// Wraps a path in single quotes for the shell, escaping embedded single quotes.
    func shellQuoted(_ path: String) -> String {
        "'" + path.replacingOccurrences(of: "'", with: "'\\''") + "'"
    }

    let mainImage = CapImage(cap: cap, image: 0)
    let inputUrl = imageUrl(base: imageDirectory, image: mainImage)
    guard FileManager.default.fileExists(atPath: inputUrl.path) else {
        throw TrainingError.missingMainImage(cap)
    }
    let thumbnailUrl = thumbnailDirectory.appendingPathComponent(String(format: "%04d.jpg", cap))

    // Paths are quoted so that a content folder containing spaces or shell
    // metacharacters doesn't break (or alter) the command.
    let command = "magick convert \(shellQuoted(inputUrl.path)) -quality 70% -resize 100x100 \(shellQuoted(thumbnailUrl.path))"
    let result: (code: Int32, output: String)
    do {
        result = try safeShell(command)
    } catch {
        throw TrainingError.failedToCreateThumbnail(cap, "\(error)")
    }
    // Checked outside of the `do` block, so the command output is reported as-is
    // instead of being wrapped a second time by the `catch` above.
    guard result.code == 0 else {
        throw TrainingError.failedToCreateThumbnail(cap, result.output)
    }

    let data: Data
    do {
        data = try Data(contentsOf: thumbnailUrl)
    } catch {
        throw TrainingError.failedToReadCreatedThumbnail(cap, error)
    }
    do {
        try await post(url: server.appendingPathComponent("thumbnails/\(cap)"), body: data)
    } catch {
        throw TrainingError.failedToUploadCreatedThumbnail(cap, error)
    }
}
// MARK: Helper
/// Runs a command through `/bin/zsh -cl` and captures its combined standard output
/// and standard error.
///
/// - Parameter command: The command line handed to the shell.
/// - Returns: The exit status of the shell and everything it wrote, decoded as UTF-8.
/// - Throws: Any error thrown when launching the process.
@discardableResult
private func safeShell(_ command: String) throws -> (code: Int32, output: String) {
    let task = Process()
    let pipe = Pipe()
    task.standardOutput = pipe
    task.standardError = pipe
    task.arguments = ["-cl", command]
    task.executableURL = URL(fileURLWithPath: "/bin/zsh")
    task.standardInput = nil
    try task.run()
    // Drain the pipe before waiting for the process: a child that writes more than
    // the pipe buffer holds would otherwise block forever, and so would we.
    let data = pipe.fileHandleForReading.readDataToEndOfFile()
    task.waitUntilExit()
    // Lossy decoding instead of a force unwrap, so non-UTF-8 output can't crash.
    let output = String(decoding: data, as: UTF8.self)
    return (task.terminationStatus, output)
}
/// Creates the folder, including intermediate directories, unless something
/// already exists at its path.
///
/// - Parameter folder: The directory to create.
/// - Throws: Any error thrown by `FileManager` when creating the directory.
private func createFolderIfMissing(_ folder: URL) throws {
    let fileManager = FileManager.default
    if fileManager.fileExists(atPath: folder.path) {
        return
    }
    try fileManager.createDirectory(at: folder, withIntermediateDirectories: true)
}
// MARK: Requests
/// Sends `body` to `url` in a POST request, with the configured authentication
/// token in the `key` header. The response body is ignored.
///
/// - Parameters:
///   - url: The endpoint to post to.
///   - body: The request body.
/// - Throws: Any error thrown by `perform(_:)`.
private func post(url: URL, body: Data) async throws {
    var upload = URLRequest(url: url)
    upload.httpMethod = "POST"
    upload.addValue(configuration.authenticationToken, forHTTPHeaderField: "key")
    upload.httpBody = body
    let _ = try await perform(upload)
}
/// Performs the request and returns the response body if the status code is 200.
///
/// - Parameter request: The request to send.
/// - Returns: The response body.
/// - Throws: Any error from `URLSession`, `URLError(.badServerResponse)` if the response
///   is not an HTTP response, `TrainingError.invalidResponse` for a status code other than 200.
private func perform(_ request: URLRequest) async throws -> Data {
    let (data, response) = try await URLSession.shared.data(for: request)
    // Conditional cast instead of `as!`, so an unexpected response type can't crash.
    guard let httpResponse = response as? HTTPURLResponse else {
        throw URLError(.badServerResponse)
    }
    guard httpResponse.statusCode == 200 else {
        guard let url = request.url ?? httpResponse.url else {
            throw URLError(.badURL)
        }
        throw TrainingError.invalidResponse(url, httpResponse.statusCode)
    }
    return data
}
/// Performs a request for `url` without further configuration and returns the raw response body.
///
/// - Parameter url: The URL to fetch.
/// - Returns: The response body.
/// - Throws: Any error thrown by `perform(_:)`.
private func get(_ url: URL) async throws -> Data {
    let request = URLRequest(url: url)
    return try await perform(request)
}
/// Fetches `url` and decodes the response body as UTF-8 text.
///
/// - Parameter url: The URL to fetch.
/// - Returns: The response body as a string.
/// - Throws: Any error from the request, or `TrainingError.invalidGetResponseData`
///   (with the byte count) if the body is not valid UTF-8.
private func get(_ url: URL) async throws -> String {
    let raw: Data = try await get(url)
    if let text = String(data: raw, encoding: .utf8) {
        return text
    }
    throw TrainingError.invalidGetResponseData(raw.count)
}
}