583 lines
21 KiB
Swift
583 lines
21 KiB
Swift
import Foundation
|
|
import CreateML
|
|
import Combine
|
|
|
|
final class ClassifierCreator {
|
|
|
|
/// Revision passed to the scene print feature extractor when training.
private let scenePrintRevision: Int? = 2

/// Value used as the `reportInterval` of the training session.
private let sessionProgressIterations: Int = 10

/// Value used as the `checkpointInterval` of the training session.
private let sessionCheckpointIterations: Int = 100

/// Value used as the `iterations` parameter of the training session.
private let sessionIterations: Int = 1000

/// Edge length used for both dimensions of the `-resize` argument when creating thumbnails.
private let thumbnailSize: Int = 100

/// Base URL of the server, obtained from the configuration in the initializer.
let server: URL

/// The configuration this instance was created with.
let configuration: Configuration

/// The number of images required to include a cap in training
let minimumImagesPerCap: Int

/// Local folder `images` inside the content folder; holds one sub-folder per cap.
let imageDirectory: URL

/// Local folder `thumbnails` inside the content folder.
let thumbnailDirectory: URL

/// Local folder `session` inside the content folder; handed to the training session.
let sessionDirectory: URL

/// Local path `classifier.mlmodel` inside the content folder.
let classifierUrl: URL

/// Ephemeral URL session used for all requests and downloads of this instance.
let urlSession: URLSession = .init(configuration: .ephemeral)

/// Formatter for server timestamps; its format is assigned in the initializer.
let df: DateFormatter = .init()
|
|
|
|
/// Writes an informational message to standard output, prefixed with `[INFO] `.
/// - Parameter info: The message text to print.
private func print(info: String) {
    let line = "[INFO] " + info
    Swift.print(line)
}
|
|
|
|
// MARK: Step 1: Load configuration
|
|
|
|
/// Creates a classifier creator from the given configuration.
///
/// All working paths (`images`, `session`, `thumbnails`, `classifier.mlmodel`)
/// are derived from `configuration.contentFolder`.
/// - Parameters:
///   - configuration: The settings providing server URL, content folder and image threshold.
///   - resume: When `false`, an existing training session folder is removed before returning.
/// - Throws: Whatever `configuration.serverUrl()` throws, or
///   `TrainingError.failedToRemoveSessionFolder` if old session data can't be deleted.
init(configuration: Configuration, resume: Bool) throws {
    self.configuration = configuration
    self.server = try configuration.serverUrl()
    self.minimumImagesPerCap = configuration.minimumImagesPerCap

    let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
    self.imageDirectory = contentDirectory.appendingPathComponent("images")
    self.sessionDirectory = contentDirectory.appendingPathComponent("session")
    self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
    self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")

    // The timestamps are a fixed machine-readable format exchanged with the server,
    // so they must not be affected by the user's calendar or 12/24-hour preferences.
    df.locale = Locale(identifier: "en_US_POSIX")
    df.dateFormat = "yy-MM-dd-HH-mm-ss"

    if !resume {
        try removeSessionData()
    }
}
|
|
|
|
/// Deletes the training session folder if it exists.
/// - Throws: `TrainingError.failedToRemoveSessionFolder` when the folder exists but can't be removed.
private func removeSessionData() throws {
    let fileManager = FileManager.default
    if !fileManager.fileExists(atPath: sessionDirectory.path) {
        // Nothing to clean up
        return
    }
    do {
        try fileManager.removeItem(at: sessionDirectory)
    } catch {
        throw TrainingError.failedToRemoveSessionFolder(error)
    }
    print(info: "Removed training session data")
}
|
|
|
|
// MARK: Main function
|
|
|
|
/// Runs the complete workflow: load images, train, upload the model and classes, create thumbnails.
///
/// Returns early when no classes are found. When no images changed, only the
/// thumbnails are created and no classifier is trained or uploaded.
/// - Parameter skipTraining: When `true`, the training step is skipped and the
///   existing model file at `classifierUrl` is uploaded.
/// - Throws: Any error thrown by the individual steps.
func run(skipTraining: Bool) async throws {
    // Taken before the images are loaded; later sent to the server as the update date.
    let imagesSnapshotDate = Date()
    let loaded = try await loadImages()
    let classes = loaded.classes
    let changedMainImages = loaded.changedMainImages

    if classes.isEmpty {
        print(info: "No image classes found, exiting...")
        return
    }

    if loaded.changedImageCount <= 0 {
        print(info: "No changed images, so no new classifier trained")
        try await createThumbnails(changed: changedMainImages)
        print(info: "Done")
        return
    }

    let newVersion = try await getClassifierVersion() + 1

    let summary = [
        "Server: \(server.path)",
        "Image directory: \(imageDirectory.absoluteURL.path)",
        "Model path: \(classifierUrl.path)",
        "Version: \(newVersion)",
        "Classes: \(classes.count)",
        "Iterations: \(configuration.trainingIterations)",
    ]
    for line in summary {
        print(info: line)
    }

    if skipTraining == false {
        try await trainAndSaveModel()
    }
    try await uploadModel(version: newVersion)
    try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
    try await createThumbnails(changed: changedMainImages)
    try removeSessionData()
    print(info: "Done")
}
|
|
|
|
// MARK: Step 2: Load changed images
|
|
|
|
/// Synchronizes the local image folder with the server.
///
/// Downloads all images that are missing locally, re-downloads the images
/// listed as changed by the server, and removes local folders that don't
/// belong to a trained cap.
/// - Returns: The sorted ids of all caps included in training, the number of
///   missing plus changed images (unfiltered change list), and the ids of the
///   trained caps whose main image (index 0) is listed as changed.
/// - Throws: `TrainingError` values from folder creation, the server requests,
///   the image downloads or the folder cleanup.
func loadImages() async throws -> (classes: [Int], changedImageCount: Int, changedMainImages: [Int]) {
    do {
        try createFolderIfMissing(imageDirectory)
    } catch {
        throw TrainingError.mainImageFolderNotCreated(error)
    }

    let imageCounts = try await getImageCounts()
        .filter { id, count in
            // Exclude caps with small counts to ensure proper training
            if count < self.minimumImagesPerCap {
                print(info: "Excluding cap \(id) from training (\(count) images)")
                return false
            }
            return true
        }

    // `imageCounts` is already restricted to caps with enough images,
    // so no second threshold filter is needed here.
    let missingImageList: [CapImage] = imageCounts
        .sorted { $0.key < $1.key }
        .reduce(into: []) { list, pair in
            let missingImagesForCap: [CapImage] = (0..<pair.value).compactMap { index -> CapImage? in
                let image = CapImage(cap: pair.key, image: index)
                let url = imageUrl(base: imageDirectory, image: image)
                guard !FileManager.default.fileExists(atPath: url.path) else {
                    return nil
                }
                return image
            }
            list.append(contentsOf: missingImagesForCap)
        }

    if missingImageList.isEmpty {
        print(info: "No missing images to load")
    } else {
        print(info: "Loading \(missingImageList.count) missing images...")
        try await loadImagesInBatches(missingImageList)
    }

    let changedImageList = try await getChangedImageList()

    // Track every (cap, image) pair that has been downloaded or queued. Images
    // within a batch are loaded concurrently, so the same image must not be
    // queued twice: both downloads would write to the same local file.
    var queued: [Int: Set<Int>] = [:]
    for image in missingImageList {
        queued[image.cap, default: []].insert(image.image)
    }
    let filteredChangeList = changedImageList.filter { entry in
        // Filter non-existent images (and images of caps excluded from training)
        guard entry.image < imageCounts[entry.cap] ?? 0 else {
            return false
        }
        // Skip images just loaded as missing, and repeated change entries
        return queued[entry.cap, default: []].insert(entry.image).inserted
    }

    let imagesAlreadyLoad = changedImageList.count - filteredChangeList.count
    let suffix = imagesAlreadyLoad > 0 ? " (\(imagesAlreadyLoad) already loaded)" : ""
    if filteredChangeList.isEmpty {
        print(info: "No changed images to load" + suffix)
    } else {
        print(info: "Loading \(filteredChangeList.count) changed images" + suffix)
        try await loadImagesInBatches(filteredChangeList)
    }

    // Only report main images of trained caps: images of other caps are never
    // downloaded and their folders are removed below, so creating a thumbnail
    // for them would always fail with a missing main image.
    let changedMainImages = changedImageList
        .filter { $0.image == 0 && imageCounts[$0.cap] != nil }
        .map { $0.cap }
    let classes = imageCounts.keys.sorted()

    // Delete any image folders not present as caps
    try deleteUnnecessaryImageFolders(caps: classes)

    return (classes, missingImageList.count + changedImageList.count, changedMainImages)
}
|
|
|
|
/// Downloads `caps.json` from the server and maps every cap id to its image count.
/// - Returns: A dictionary from cap id to the `count` of that cap.
/// - Throws: `TrainingError.failedToGetCapDatabase` if the request fails, or
///   `TrainingError.failedToDecodeCapDatabase` if the response can't be decoded.
private func getImageCounts() async throws -> [Int : Int] {
    let data: Data
    do {
        data = try await get(server.appendingPathComponent("caps.json"))
    } catch {
        throw TrainingError.failedToGetCapDatabase(error)
    }

    let caps: [Cap]
    do {
        caps = try JSONDecoder().decode([Cap].self, from: data)
    } catch {
        throw TrainingError.failedToDecodeCapDatabase(error)
    }

    // A later entry with the same id replaces an earlier one
    var counts = [Int : Int]()
    for cap in caps {
        counts[cap.id] = cap.count
    }
    return counts
}
|
|
|
|
/// Removes every item in the image directory that is not the folder of a given cap.
///
/// Folder names are the cap ids formatted with `%04d`. Any other directory
/// entry, regardless of its kind, is deleted.
/// - Parameter caps: The ids of the caps whose folders are kept.
/// - Throws: `TrainingError.failedToGetListOfImageFolders` if the directory can't be listed,
///   or `TrainingError.failedToRemoveImageFolder` if an item can't be removed.
private func deleteUnnecessaryImageFolders(caps: [Int]) throws {
    // A set keeps the membership test constant-time for each directory entry
    let validNames = Set(caps.map { String(format: "%04d", $0) })
    let folders: [String]
    do {
        folders = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
    } catch {
        throw TrainingError.failedToGetListOfImageFolders(error)
    }

    for folder in folders where !validNames.contains(folder) {
        // Not a valid cap folder
        let url = imageDirectory.appendingPathComponent(folder)
        do {
            try FileManager.default.removeItem(at: url)
            print(info: "Removed item in image folder: '\(folder)'")
        } catch {
            throw TrainingError.failedToRemoveImageFolder(folder, error)
        }
    }
}
|
|
|
|
/// Builds the URL of a cap's folder below a base URL.
/// - Parameters:
///   - base: The URL containing the cap folders.
///   - cap: The cap id, used as the folder name zero-padded to at least four digits.
/// - Returns: `base` with the folder name appended.
private func capFolderUrl(base: URL, cap: Int) -> URL {
    let folderName = String(format: "%04d", cap)
    return base.appendingPathComponent(folderName)
}
|
|
|
|
/// Builds the URL of a single cap image below a base URL.
/// - Parameters:
///   - base: The URL containing the cap folders.
///   - image: The cap id and image index identifying the file.
/// - Returns: `<base>/<cap folder>/<cap>-<image>.jpg`, with the cap padded to
///   at least four and the image index to at least two digits.
private func imageUrl(base: URL, image: CapImage) -> URL {
    let fileName = String(format: "%04d-%02d.jpg", image.cap, image.image)
    let folder = capFolderUrl(base: base, cap: image.cap)
    return folder.appendingPathComponent(fileName)
}
|
|
|
|
/// Downloads one image from the server and stores it in the local image folder.
///
/// An existing local file for the same image is replaced.
/// - Parameter image: The cap id and image index to download.
/// - Throws: `TrainingError.failedToCreateImageFolder`, `.failedToLoadImage`,
///   `.invalidImageRequestResponse` (status other than 200) or `.failedToSaveImage`.
private func load(image: CapImage) async throws {
    do {
        try createFolderIfMissing(capFolderUrl(base: imageDirectory, cap: image.cap))
    } catch {
        throw TrainingError.failedToCreateImageFolder(image.cap, error)
    }

    let remoteUrl = imageUrl(base: server.appendingPathComponent("images"), image: image)
    let tempFile: URL, response: URLResponse
    do {
        (tempFile, response) = try await urlSession.download(from: remoteUrl)
    } catch {
        throw TrainingError.failedToLoadImage(image, error)
    }
    // Best-effort cleanup: after a successful move the temporary file no longer
    // exists and the removal fails silently; on every error path it prevents
    // the download from being left behind.
    defer { try? FileManager.default.removeItem(at: tempFile) }

    // A response that isn't HTTP is reported as a load failure instead of crashing
    guard let httpResponse = response as? HTTPURLResponse else {
        throw TrainingError.failedToLoadImage(image, URLError(.badServerResponse))
    }
    let responseCode = httpResponse.statusCode
    guard responseCode == 200 else {
        throw TrainingError.invalidImageRequestResponse(image, responseCode)
    }

    do {
        let localUrl = imageUrl(base: imageDirectory, image: image)
        // `moveItem` doesn't overwrite, so remove a previous version first
        if FileManager.default.fileExists(atPath: localUrl.path) {
            try FileManager.default.removeItem(at: localUrl)
        }
        try FileManager.default.moveItem(at: tempFile, to: localUrl)
    } catch {
        throw TrainingError.failedToSaveImage(image, error)
    }
}
|
|
|
|
/// Downloads a list of images batch by batch, logging the progress before each batch.
/// - Parameters:
///   - list: The images to download. Nothing happens for an empty list.
///   - batchSize: The number of images per batch. Values below 1 are treated as 1.
/// - Throws: The error of the first batch that fails; later batches are not started.
private func loadImagesInBatches(_ list: [CapImage], batchSize: Int = 100) async throws {
    guard !list.isEmpty else {
        return
    }
    print(info: "Loading \(list.count) images...")

    // A step of zero would never advance and a negative one would produce an
    // invalid range, so clamp to at least one image per batch.
    let step = max(1, batchSize)
    for start in stride(from: 0, to: list.count, by: step) {
        let end = min(start + step, list.count)
        let batch = Array(list[start..<end])
        print(info: "Completed \(start) / \(list.count) images")
        try await loadImageBatch(batch)
    }
    print(info: "Completed \(list.count) / \(list.count) images")
}
|
|
|
|
/// Downloads all images of a batch as child tasks of one task group.
///
/// Every image is attempted even if others fail. The description of each
/// failure is printed before the function throws.
/// - Parameter list: The images to download.
/// - Throws: `TrainingError.failedToLoadImages` if at least one image could not be loaded.
private func loadImageBatch(_ list: [CapImage]) async throws {
    let failures = await withTaskGroup(of: Error?.self, returning: [Error].self) { group in
        for image in list {
            group.addTask {
                do {
                    try await self.load(image: image)
                    return nil
                } catch {
                    return error
                }
            }
        }

        // Each child yields `nil` on success or its error on failure
        var collected = [Error]()
        for await outcome in group {
            if let outcome {
                collected.append(outcome)
            }
        }
        return collected
    }

    failures.forEach { Swift.print($0.localizedDescription) }

    // Every child that didn't report an error counts as loaded
    let loadedImageCount = list.count - failures.count
    guard loadedImageCount == list.count else {
        throw TrainingError.failedToLoadImages(expected: list.count, loaded: loadedImageCount)
    }
}
|
|
|
|
/// Downloads `changes.txt` from the server and parses it into a list of changed images.
///
/// Every non-empty line must have the form `<date>:<cap>:<image>`, where the
/// date matches the format of `df` and both other parts are integers. The
/// date is only validated, not returned.
/// - Returns: One entry per valid line, in file order.
/// - Throws: `TrainingError.failedToGetChangedImagesList` if the request fails, or
///   `TrainingError.invalidEntryInChangeList` for the first malformed line.
func getChangedImageList() async throws -> [CapImage] {
    let string: String
    do {
        string = try await get(server.appendingPathComponent("changes.txt"))
    } catch {
        throw TrainingError.failedToGetChangedImagesList(error)
    }

    return try string
        // Splitting on all newline characters also handles CRLF line endings,
        // which would otherwise leave a "\r" on the last field of each entry.
        .components(separatedBy: .newlines)
        .map { $0.trimmingCharacters(in: .whitespaces) }
        .filter { !$0.isEmpty }
        .map { entry -> CapImage in
            let parts = entry.components(separatedBy: ":")
            guard parts.count == 3,
                  df.date(from: parts[0]) != nil,
                  let cap = Int(parts[1]),
                  let image = Int(parts[2]) else {
                throw TrainingError.invalidEntryInChangeList(entry)
            }
            return CapImage(cap: cap, image: image)
        }
}
|
|
|
|
// MARK: Step 3: Compute version
|
|
|
|
/// Requests the current classifier version from the server.
/// - Returns: The version parsed from the response of the `version` endpoint.
/// - Throws: `TrainingError.failedToGetClassifierVersion` if the request fails, or
///   `TrainingError.invalidClassifierVersion` if the response is not an integer.
func getClassifierVersion() async throws -> Int {
    let string: String
    do {
        string = try await get(server.appendingPathComponent("version"))
    } catch {
        throw TrainingError.failedToGetClassifierVersion(error)
    }

    // Tolerate surrounding whitespace such as a trailing newline in the response
    let trimmed = string.trimmingCharacters(in: .whitespacesAndNewlines)
    guard let version = Int(trimmed) else {
        throw TrainingError.invalidClassifierVersion(string)
    }
    return version
}
|
|
|
|
// MARK: Step 4: Train classifier
|
|
|
|
/// Trains a classifier and writes it to `classifierUrl`.
/// - Throws: Errors from training or from writing the model.
func trainAndSaveModel() async throws {
    // Synchronous alternative: `try trainModelSync()`
    let classifier = try await trainModelAsync()
    try save(model: classifier)
}
|
|
|
|
/// Trains an image classifier through a Create ML training job.
///
/// The training data are the labeled folders in `imageDirectory`. While the
/// job runs, the completed fraction is printed on a single console line.
/// - Returns: The classifier delivered by the job.
/// - Throws: `TrainingError.failedToCreateClassifier` if the job can't be
///   created or finishes with a failure.
func trainModelAsync() async throws -> MLImageClassifier {
    let params = MLImageClassifier.ModelParameters(
        maxIterations: configuration.trainingIterations,
        augmentation: [],
        algorithm: .transferLearning(
            featureExtractor: .scenePrint(revision: scenePrintRevision),
            classifier: .logisticRegressor))
    let sessionParameters = MLTrainingSessionParameters(
        sessionDirectory: sessionDirectory,
        reportInterval: sessionProgressIterations,
        checkpointInterval: sessionCheckpointIterations,
        iterations: sessionIterations)

    var subscriptions = [AnyCancellable]()
    // The subscriptions must stay alive until the job has delivered its result:
    // releasing an `AnyCancellable` cancels it, and the continuation below
    // would then never be resumed. Using them at scope exit pins their lifetime
    // across the suspension instead of relying on the compiler.
    defer { subscriptions.removeAll() }

    let job: MLJob<MLImageClassifier>
    do {
        job = try MLImageClassifier.train(
            trainingData: .labeledDirectories(at: imageDirectory),
            parameters: params,
            sessionParameters: sessionParameters)
    } catch {
        throw TrainingError.failedToCreateClassifier(error)
    }

    job.progress
        .publisher(for: \.fractionCompleted)
        .sink { completed in
            // "\r" returns to the line start so each update overwrites the previous one
            Swift.print(String(format: " %.1f %% completed", completed * 100), terminator: "\r")
            fflush(stdout)
        }
        .store(in: &subscriptions)

    return try await withCheckedThrowingContinuation { continuation in
        // Register a sink to receive the resulting model.
        job.result.sink { result in
            switch result {
            case .finished:
                break // Continuation already called with model
            case .failure(let error):
                continuation.resume(throwing: TrainingError.failedToCreateClassifier(error))
            }
        } receiveValue: { [weak self] model in
            self?.print(info: "Created model")
            continuation.resume(returning: model)
        }
        .store(in: &subscriptions)
    }
}
|
|
|
|
/// Trains an image classifier with the synchronous Create ML initializer.
///
/// Only the iteration limit and an empty augmentation are specified here.
/// - Returns: The trained classifier.
/// - Throws: `TrainingError.failedToCreateClassifier` if training fails.
func trainModelSync() throws -> MLImageClassifier {
    let parameters = MLImageClassifier.ModelParameters(
        maxIterations: configuration.trainingIterations,
        augmentation: [])
    let trainingData = MLImageClassifier.DataSource.labeledDirectories(at: imageDirectory)

    let classifier: MLImageClassifier
    do {
        classifier = try MLImageClassifier(trainingData: trainingData, parameters: parameters)
    } catch {
        throw TrainingError.failedToCreateClassifier(error)
    }
    return classifier
}
|
|
|
|
/// Writes a trained classifier to `classifierUrl`.
/// - Parameter model: The classifier to store.
/// - Throws: `TrainingError.failedToWriteClassifier` if the model can't be written.
private func save(model: MLImageClassifier) throws {
    print(info: "Saving classifier...")
    let destination = classifierUrl
    do {
        try model.write(to: destination)
    } catch {
        throw TrainingError.failedToWriteClassifier(error)
    }
}
|
|
|
|
// MARK: Step 5: Upload classifier
|
|
|
|
/// Uploads the model file at `classifierUrl` to the server.
/// - Parameter version: The version number appended to the `classifier/` endpoint.
/// - Throws: `TrainingError.failedToReadClassifierData` if the file can't be read, or
///   `TrainingError.failedToUploadClassifier` if the request fails.
func uploadModel(version: Int) async throws {
    print(info: "Uploading classifier...")

    let modelData: Data
    do {
        modelData = try Data(contentsOf: classifierUrl)
    } catch {
        throw TrainingError.failedToReadClassifierData(error)
    }

    let endpoint = server.appendingPathComponent("classifier/\(version)")
    do {
        try await post(url: endpoint, body: modelData)
    } catch {
        throw TrainingError.failedToUploadClassifier(error)
    }
}
|
|
|
|
// MARK: Step 6: Update classes
|
|
|
|
/// Uploads the list of trained classes to the server.
///
/// The body is the comma-separated list of class ids; the date is formatted
/// with `df` and appended to the `classes/` endpoint.
/// - Parameters:
///   - classes: The ids of the caps included in the classifier.
///   - lastUpdate: The date sent as part of the request path.
/// - Throws: `TrainingError.failedToUploadClassifierClasses` if the request fails.
func upload(classes: [Int], lastUpdate: Date) async throws {
    print(info: "Uploading trained classes...")
    let dateString = df.string(from: lastUpdate)
    let url = server.appendingPathComponent("classes/\(dateString)")
    // `Data(_:)` on the UTF-8 view can't fail, so no force unwrap is needed
    let body = Data(classes.map { String($0) }.joined(separator: ",").utf8)
    do {
        try await post(url: url, body: body)
    } catch {
        throw TrainingError.failedToUploadClassifierClasses(error)
    }
}
|
|
|
|
// MARK: Step 7: Create thumbnails
|
|
|
|
/// Creates and uploads thumbnails for changed caps and for caps the server reports as missing.
/// - Parameter changed: Ids of caps whose thumbnail must be recreated.
/// - Throws: Errors from the ImageMagick check, `TrainingError.failedToCreateThumbnailFolder`,
///   or any error from requesting the missing ids or creating a single thumbnail.
func createThumbnails(changed: [Int]) async throws {
    try ensureMagickAvailability()

    do {
        try createFolderIfMissing(thumbnailDirectory)
    } catch {
        throw TrainingError.failedToCreateThumbnailFolder(error)
    }

    let capIdsOfMissingThumbnails = try await getMissingThumbnailIds()
    // Merge both sources so that every cap is processed only once
    var pending = Set(capIdsOfMissingThumbnails)
    pending.formUnion(changed)

    print(info: "Creating \(pending.count) thumbnails...")
    for cap in pending {
        try await createThumbnail(for: cap)
    }
}
|
|
|
|
/// Checks that the `magick` command is available and logs its version.
///
/// Runs `magick --version` and takes the first word after `ImageMagick ` in
/// the output as the version.
/// - Throws: `TrainingError.magickDependencyCheckFailed` if the command can't be run, or
///   `TrainingError.magickDependencyNotFound` if it exits with a non-zero code
///   or prints no recognizable version.
func ensureMagickAvailability() throws {
    // Only the shell invocation is wrapped: a `throw` inside this `do` would be
    // caught by its own `catch` and reported as a failed check instead.
    let result: (code: Int32, output: String)
    do {
        result = try safeShell("magick --version")
    } catch {
        throw TrainingError.magickDependencyCheckFailed(error)
    }

    guard result.code == 0,
          let version = result.output
            .components(separatedBy: "ImageMagick ")
            .dropFirst().first?
            .components(separatedBy: " ").first else {
        throw TrainingError.magickDependencyNotFound
    }
    print(info: "Using magick \(version)")
}
|
|
|
|
/// Requests the ids of all caps without a thumbnail from the server.
///
/// The response of `thumbnails/missing` is read as a comma-separated list.
/// Parts that are not integers are ignored.
/// - Returns: The parsed cap ids.
/// - Throws: `TrainingError.failedToGetMissingThumbnailIds` if the request fails.
private func getMissingThumbnailIds() async throws -> [Int] {
    let string: String
    do {
        string = try await get(server.appendingPathComponent("thumbnails/missing"))
    } catch {
        throw TrainingError.failedToGetMissingThumbnailIds(error)
    }
    return string
        .components(separatedBy: ",")
        // Trim first, so an id followed by a newline or space isn't dropped
        .compactMap { Int($0.trimmingCharacters(in: .whitespacesAndNewlines)) }
}
|
|
|
|
/// Creates the thumbnail of a cap from its main image and uploads it to the server.
///
/// The main image (index 0) is converted with `magick convert` and written
/// to the thumbnail directory as `<cap>.jpg`.
/// - Parameter cap: The id of the cap.
/// - Throws: `TrainingError.missingMainImage` if the main image doesn't exist locally,
///   `.failedToCreateThumbnail` if the conversion can't be run or fails,
///   `.failedToReadCreatedThumbnail` or `.failedToUploadCreatedThumbnail`.
private func createThumbnail(for cap: Int) async throws {
    let mainImage = CapImage(cap: cap, image: 0)
    let inputUrl = imageUrl(base: imageDirectory, image: mainImage)
    guard FileManager.default.fileExists(atPath: inputUrl.path) else {
        throw TrainingError.missingMainImage(cap)
    }

    let thumbnailUrl = thumbnailDirectory.appendingPathComponent(String(format: "%04d.jpg", cap))

    // Wraps a path in single quotes for the shell, so that spaces and other
    // special characters in the content folder path can't split or alter the command.
    func shellQuoted(_ path: String) -> String {
        "'" + path.replacingOccurrences(of: "'", with: "'\\''") + "'"
    }
    let command = "magick convert \(shellQuoted(inputUrl.path)) -quality 70% -resize \(thumbnailSize)x\(thumbnailSize) \(shellQuoted(thumbnailUrl.path))"

    // Only the shell invocation is wrapped, so that the tool output of a failed
    // conversion is reported as-is instead of being caught and re-wrapped.
    let result: (code: Int32, output: String)
    do {
        result = try safeShell(command)
    } catch {
        throw TrainingError.failedToCreateThumbnail(cap, "\(error)")
    }
    guard result.code == 0 else {
        throw TrainingError.failedToCreateThumbnail(cap, result.output)
    }

    let data: Data
    do {
        data = try Data(contentsOf: thumbnailUrl)
    } catch {
        throw TrainingError.failedToReadCreatedThumbnail(cap, error)
    }
    do {
        try await post(url: server.appendingPathComponent("thumbnails/\(cap)"), body: data)
    } catch {
        throw TrainingError.failedToUploadCreatedThumbnail(cap, error)
    }
}
|
|
|
|
// MARK: Helper
|
|
|
|
/// Runs a command in a `/bin/zsh` login shell and captures its output.
///
/// Standard output and standard error are captured together through one pipe.
/// - Parameter command: The command line passed to the shell with `-cl`.
/// - Returns: The termination status of the shell and the captured output.
///   Bytes that are not valid UTF-8 are replaced during decoding.
/// - Throws: The error of `Process.run()` if the shell can't be launched.
@discardableResult
private func safeShell(_ command: String) throws -> (code: Int32, output: String) {
    let task = Process()
    let pipe = Pipe()

    task.standardOutput = pipe
    task.standardError = pipe
    task.standardInput = nil
    task.executableURL = URL(fileURLWithPath: "/bin/zsh")
    task.arguments = ["-cl", command]

    try task.run()

    // Drain the pipe before waiting: a process that writes more than the pipe
    // can buffer blocks until someone reads, so waiting first could never return.
    let data = pipe.fileHandleForReading.readDataToEndOfFile()
    task.waitUntilExit()

    // Lossy decoding instead of a force unwrap, so unexpected bytes can't crash
    let output = String(decoding: data, as: UTF8.self)
    return (task.terminationStatus, output)
}
|
|
|
|
|
|
/// Creates a folder, including intermediate folders, unless something already exists at its path.
/// - Parameter folder: The folder to create.
/// - Throws: The error of `FileManager.createDirectory` if the folder can't be created.
private func createFolderIfMissing(_ folder: URL) throws {
    let fileManager = FileManager.default
    if fileManager.fileExists(atPath: folder.path) {
        return
    }
    try fileManager.createDirectory(at: folder, withIntermediateDirectories: true)
}
|
|
|
|
// MARK: Requests
|
|
|
|
/// Sends a POST request with the given body and the authentication token in the `key` header.
/// - Parameters:
///   - url: The URL to send the request to.
///   - body: The data sent as the request body.
/// - Throws: Any error thrown while performing the request. The response data is discarded.
private func post(url: URL, body: Data) async throws {
    var request = URLRequest(url: url)
    request.httpBody = body
    request.httpMethod = "POST"
    request.addValue(configuration.authenticationToken, forHTTPHeaderField: "key")
    let _ = try await perform(request)
}
|
|
|
|
/// Performs a request and returns the response body if the status code is 200.
/// - Parameter request: The request to perform.
/// - Returns: The data of the response.
/// - Throws: The error of the URL session, `URLError(.badServerResponse)` if the
///   response is not an HTTP response, or `TrainingError.invalidResponse` for
///   any status code other than 200.
private func perform(_ request: URLRequest) async throws -> Data {
    let (data, response) = try await urlSession.data(for: request)
    // Conditional casts instead of force unwraps, so an unexpected response
    // produces an error the callers can wrap, not a crash.
    guard let httpResponse = response as? HTTPURLResponse, let url = request.url else {
        throw URLError(.badServerResponse)
    }
    let code = httpResponse.statusCode
    guard code == 200 else {
        throw TrainingError.invalidResponse(url, code)
    }
    return data
}
|
|
|
|
/// Performs a plain request to a URL and returns the response body.
/// - Parameter url: The URL to request.
/// - Returns: The data of the response.
/// - Throws: Any error thrown while performing the request.
private func get(_ url: URL) async throws -> Data {
    let request = URLRequest(url: url)
    return try await perform(request)
}
|
|
|
|
/// Performs a plain request to a URL and decodes the response body as UTF-8 text.
/// - Parameter url: The URL to request.
/// - Returns: The response body as a string.
/// - Throws: Any error thrown while performing the request, or
///   `TrainingError.invalidGetResponseData` if the body is not valid UTF-8.
private func get(_ url: URL) async throws -> String {
    let data: Data = try await get(url)
    if let string = String(data: data, encoding: .utf8) {
        return string
    }
    throw TrainingError.invalidGetResponseData(data.count)
}
|
|
}
|