Add code from training script

This commit is contained in:
Christoph Hagen 2023-10-23 12:28:35 +02:00
parent fd484243be
commit 9a50328b93
8 changed files with 582 additions and 17 deletions

View File

@ -0,0 +1,6 @@
{
"contentDirectory": "../Public",
"trainingIterations": 20,
"serverPath": "https://mydomain.com/caps",
"authenticationToken": "mysecretkey",
}

View File

@ -1,16 +1,14 @@
// swift-tools-version: 5.9
// The swift-tools-version declares the minimum version of Swift required to build this package.
import PackageDescription
let package = Package(
name: "Cap-Train",
platforms: [.macOS(.v12)],
dependencies: [
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.0"),
],
targets: [
// Targets are the basic building blocks of a package, defining a module or a test suite.
// Targets can depend on other targets in this package and products from dependencies.
.executableTarget(
name: "Cap-Train",
dependencies: [

27
Sources/Cap.swift Normal file
View File

@ -0,0 +1,27 @@
import Foundation
struct Cap: Codable {
let id: Int
let count: Int
enum CodingKeys: String, CodingKey {
case id = "i"
case count = "c"
}
}
struct CapImage: Equatable {
let cap: Int
let image: Int
}
extension CapImage: CustomStringConvertible {
var description: String {
"image \(image) of cap \(cap)"
}
}

18
Sources/CapTrain.swift Normal file
View File

@ -0,0 +1,18 @@
import ArgumentParser
import Foundation
@main
struct CapTrain: AsyncParsableCommand {
@Argument(help: "The path to the configuration file")
var configPath: String
func run() async throws {
let configurationFileUrl = URL(fileURLWithPath: configPath)
let configuration = try Configuration(at: configurationFileUrl)
let creator = try ClassifierCreator(configuration: configuration)
try await creator.run()
}
}

View File

@ -1,14 +0,0 @@
// The Swift Programming Language
// https://docs.swift.org/swift-book
//
// Swift Argument Parser
// https://swiftpackageindex.com/apple/swift-argument-parser/documentation
import ArgumentParser
@main
struct Cap_Train: ParsableCommand {
mutating func run() throws {
print("Hello, world!")
}
}

View File

@ -0,0 +1,463 @@
import Foundation
import CreateML
final class ClassifierCreator {
let server: URL
let configuration: Configuration
let imageDirectory: URL
let thumbnailDirectory: URL
let classifierUrl: URL
let df = DateFormatter()
// MARK: Step 1: Load configuration
init(configuration: Configuration) throws {
self.configuration = configuration
self.server = try configuration.serverUrl()
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
self.imageDirectory = contentDirectory.appendingPathComponent("images")
self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
df.dateFormat = "yy-MM-dd-HH-mm-ss"
}
// MARK: Main function
func run() async throws {
let imagesSnapshotDate = Date()
let (classes, changedImageCount, changedMainImages) = try await loadImages()
guard !classes.isEmpty else {
print("[INFO] No image classes found, exiting...")
return
}
guard changedImageCount > 0 else {
print("[INFO] No changed images, so no new classifier trained")
await createThumbnails(changed: changedMainImages)
print("[INFO] Done")
return
}
let classifierVersion = try await getClassifierVersion()
let newVersion = classifierVersion + 1
print("[INFO] Image directory: \(imageDirectory.absoluteURL.path)")
print("[INFO] Model path: \(classifierUrl.path)")
print("[INFO] Version: \(newVersion)")
print("[INFO] Classes: \(classes.count)")
print("[INFO] Iterations: \(configuration.trainingIterations)")
try trainModel()
try await upload(version: newVersion)
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
await createThumbnails(changed: changedMainImages)
print("[INFO] Done")
}
// MARK: Step 2: Load changed images
func loadImages() async throws -> (classes: [Int], changedImageCount: Int, changedMainImages: [Int]) {
guard createFolderIfMissing(imageDirectory) else {
throw TrainingError.mainImageFolderNotCreated
}
let imageCounts = await getImageCounts()
let missingImageList: [CapImage] = imageCounts
.sorted { $0.key < $1.key }
.reduce(into: []) { list, pair in
let missingImagesForCap: [CapImage] = (0..<pair.value).compactMap { image in
let image = CapImage(cap: pair.key, image: image)
let url = imageUrl(base: imageDirectory, image: image)
guard !FileManager.default.fileExists(atPath: url.path) else {
return nil
}
return image
}
list.append(contentsOf: missingImagesForCap)
}
if missingImageList.isEmpty {
print("[INFO] No missing images to load")
} else {
print("[INFO] Loading \(missingImageList.count) missing images...")
try await loadImages(missingImageList)
}
let changedImageList = try await getChangedImageList()
let filteredChangeList = changedImageList
.filter { $0.image < imageCounts[$0.cap] ?? 0 } // Filter non-existent images
.filter { !missingImageList.contains($0) }
let imagesAlreadyLoad = changedImageList.count - filteredChangeList.count
let suffix = imagesAlreadyLoad > 0 ? " (\(imagesAlreadyLoad) already loaded)" : ""
if filteredChangeList.isEmpty {
print("[INFO] No changed images to load" + suffix)
} else {
print("[INFO] Loading \(filteredChangeList.count) changed images" + suffix)
try await loadImages(filteredChangeList)
}
let changedMainImages = changedImageList.filter { $0.image == 0 }.map { $0.cap }
let classes = imageCounts.keys.sorted()
// Delete any image folders not present as caps
try deleteUnnecessaryImageFolders(caps: classes)
return (classes, missingImageList.count + changedImageList.count, changedMainImages)
}
private func getImageCounts() async -> [Int : Int] {
guard let data: Data = await get(server.appendingPathComponent("caps.json")) else {
return [:]
}
do {
return try JSONDecoder().decode([Cap].self, from: data)
.reduce(into: [:]) { $0[$1.id] = $1.count }
} catch {
print("[ERROR] Failed to decode cap database: \(error)")
return [:]
}
}
private func deleteUnnecessaryImageFolders(caps: [Int]) throws {
let validNames = caps.map { String(format: "%04d", $0) }
let folders: [String]
do {
folders = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
} catch {
print("[ERROR] Failed to get list of image folders: \(error)")
throw TrainingError.failedToGetListOfImageFolders
}
for folder in folders {
if validNames.contains(folder) {
continue
}
// Not a valid cap folder
let url = imageDirectory.appendingPathComponent(folder)
do {
try FileManager.default.removeItem(at: url)
print("[INFO] Removed unused image folder '\(folder)'")
} catch {
print("[ERROR] Failed to delete unused image folder \(folder): \(error)")
throw TrainingError.failedToRemoveImageFolder
}
}
}
private func imageUrl(base: URL, image: CapImage) -> URL {
base.appendingPathComponent(String(format: "%04d/%04d-%02d.jpg", image.cap, image.cap, image.image))
}
private func load(image: CapImage) async -> Bool {
guard createFolderIfMissing(imageDirectory.appendingPathComponent(String(format: "%04d", image.cap))) else {
return false
}
let url = imageUrl(base: server.appendingPathComponent("images"), image: image)
let tempFile: URL, response: URLResponse
do {
(tempFile, response) = try await URLSession.shared.download(from: url)
} catch {
print("[ERROR] Failed to load \(image): \(error)")
return false
}
let responseCode = (response as! HTTPURLResponse).statusCode
guard responseCode == 200 else {
print("[ERROR] Failed to load \(image): Response \(responseCode)")
return false
}
do {
let localUrl = imageUrl(base: imageDirectory, image: image)
if FileManager.default.fileExists(atPath: localUrl.path) {
try FileManager.default.removeItem(at: localUrl)
}
try FileManager.default.moveItem(at: tempFile, to: localUrl)
return true
} catch {
print("[ERROR] Failed to save \(image): \(error)")
return false
}
}
private func loadImages(_ list: [CapImage]) async throws {
var loadedImages = 0
await withTaskGroup(of: Bool.self) { group in
for image in list {
group.addTask {
await self.load(image: image)
}
}
for await loaded in group {
if loaded {
loadedImages += 1
}
}
}
if loadedImages != list.count {
print("[ERROR] Only \(loadedImages) of \(list.count) images loaded")
throw TrainingError.failedToLoadImages
}
}
func getChangedImageList() async throws -> [CapImage] {
guard let string: String = await get(server.appendingPathComponent("changes.txt")) else {
print("[ERROR] Failed to get list of changed images")
throw TrainingError.failedToGetChangedImagesList
}
return string
.components(separatedBy: "\n")
.filter { $0 != "" }
.compactMap {
let parts = $0.components(separatedBy: ":")
guard parts.count == 3 else {
return nil
}
/*
guard let date = df.date(from: parts[0]) else {
print("[WARN] Invalid date \(parts[0]) in change list")
return nil
}
*/
guard let cap = Int(parts[1]) else {
print("[WARN] Invalid cap id \(parts[1]) in change list")
return nil
}
guard let image = Int(parts[2]) else {
print("[WARN] Invalid image id \(parts[2]) in change list")
return nil
}
return CapImage(cap: cap, image: image)
}
}
// MARK: Step 3: Compute version
func getClassifierVersion() async throws -> Int {
guard let string: String = await get(server.appendingPathComponent("version")) else {
print("[ERROR] Failed to get classifier version")
throw TrainingError.failedToGetClassifierVersion
}
guard let version = Int(string) else {
print("[ERROR] Invalid classifier version \(string)")
throw TrainingError.invalidClassifierVersion(string)
}
return version
}
// MARK: Step 4: Train classifier
func trainModel() throws {
var params = MLImageClassifier.ModelParameters(augmentation: [])
params.maxIterations = configuration.trainingIterations
let model: MLImageClassifier
do {
model = try MLImageClassifier(
trainingData: .labeledDirectories(at: imageDirectory),
parameters: params)
} catch {
print("[ERROR] Failed to create classifier: \(error)")
throw TrainingError.failedToCreateClassifier(error)
}
print("[INFO] Saving classifier...")
do {
try model.write(to: classifierUrl)
} catch {
print("[ERROR] Failed to save model to file: \(error)")
throw TrainingError.failedToWriteClassifier(error)
}
}
// MARK: Step 5: Upload classifier
func upload(version: Int) async throws {
print("[INFO] Uploading classifier...")
let modelData: Data
do {
modelData = try Data(contentsOf: classifierUrl)
} catch {
print("[ERROR] Failed to read classifier data: \(error)")
throw TrainingError.failedToReadClassifierData(error)
}
let success = await post(
url: server.appendingPathComponent("classifier/\(version)"),
body: modelData)
guard success else {
throw TrainingError.failedToUploadClassifier
}
}
// MARK: Step 6: Update classes
func upload(classes: [Int], lastUpdate: Date) async throws {
print("[INFO] Uploading trained classes...")
let dateString = df.string(from: lastUpdate)
let url = server.appendingPathComponent("classes/\(dateString)")
let body = classes.map(String.init).joined(separator: ",").data(using: .utf8)!
guard await post(url: url, body: body) else {
throw TrainingError.failedToUploadClassifierClasses
}
}
// MARK: Step 7: Create thumbnails
func createThumbnails(changed: [Int]) async {
guard checkMagickAvailability() else {
return
}
guard createFolderIfMissing(thumbnailDirectory) else {
print("[ERROR] Failed to create folder for thumbnails")
return
}
let capIdsOfMissingThumbnails = await getMissingThumbnailIds()
let all = Set(capIdsOfMissingThumbnails).union(changed)
print("[INFO] Creating \(all.count) thumbnails...")
for cap in all {
await createThumbnail(for: cap)
}
}
func checkMagickAvailability() -> Bool {
do {
let (code, output) = try safeShell("magick --version")
guard code == 0, let version = output.components(separatedBy: "ImageMagick ").dropFirst().first?
.components(separatedBy: " ").first else {
print("[ERROR] Magick not found, install using 'brew install imagemagick'")
return false
}
print("[INFO] Using magick \(version)")
} catch {
print("[ERROR] Failed to get version of magick: (\(error))")
return false
}
return true
}
private func getMissingThumbnailIds() async -> [Int] {
guard let string: String = await get(server.appendingPathComponent("thumbnails/missing")) else {
print("[ERROR] Failed to get missing thumbnails")
return []
}
return string.components(separatedBy: ",").compactMap(Int.init)
}
private func createThumbnail(for cap: Int) async {
let mainImage = CapImage(cap: cap, image: 0)
let inputUrl = imageUrl(base: imageDirectory, image: mainImage)
guard FileManager.default.fileExists(atPath: inputUrl.path) else {
print("[ERROR] Local main image not found for cap \(cap): \(inputUrl.path)")
return
}
let output = thumbnailDirectory.appendingPathComponent(String(format: "%04d.jpg", cap))
do {
let command = "magick convert \(inputUrl.path) -quality 70% -resize 100x100 \(output.path)"
let (code, output) = try safeShell(command)
if code != 0 {
print("Failed to create thumbnail for cap \(cap): \(output)")
return
}
} catch {
print("Failed to read created thumbnail for cap \(cap): \(error)")
return
}
let data: Data
do {
data = try Data(contentsOf: output)
} catch {
print("Failed to read created thumbnail for cap \(cap): \(error)")
return
}
guard await post(url: server.appendingPathComponent("thumbnails/\(cap)"), body: data) else {
print("Failed to upload thumbnail for cap \(cap)")
return
}
}
// MARK: Helper
@discardableResult
private func safeShell(_ command: String) throws -> (code: Int32, output: String) {
let task = Process()
let pipe = Pipe()
task.standardOutput = pipe
task.standardError = pipe
task.arguments = ["-cl", command]
task.executableURL = URL(fileURLWithPath: "/bin/zsh")
task.standardInput = nil
try task.run()
task.waitUntilExit()
let data = pipe.fileHandleForReading.readDataToEndOfFile()
let output = String(data: data, encoding: .utf8)!
return (task.terminationStatus, output)
}
private func createFolderIfMissing(_ folder: URL) -> Bool {
guard !FileManager.default.fileExists(atPath: folder.path) else {
return true
}
do {
try FileManager.default.createDirectory(at: folder, withIntermediateDirectories: true)
return true
} catch {
print("[ERROR] Failed to create directory \(folder.path): \(error)")
return false
}
}
// MARK: Requests
private func post(url: URL, body: Data) async -> Bool {
var request = URLRequest(url: url)
request.httpMethod = "POST"
request.httpBody = body
request.addValue(configuration.authenticationToken, forHTTPHeaderField: "key")
return await perform(request) != nil
}
private func perform(_ request: URLRequest) async -> Data? {
let data: Data
let response: URLResponse
do {
(data, response) = try await URLSession.shared.data(for: request)
} catch {
print("[ERROR] Request to \(request.url!.absoluteString) failed: \(error)")
return nil
}
let code = (response as! HTTPURLResponse).statusCode
guard code == 200 else {
print("[ERROR] Request to \(request.url!.absoluteString): Invalid response \(code)")
return nil
}
return data
}
private func get(_ url: URL) async -> Data? {
await perform(URLRequest(url: url))
}
private func get(_ url: URL) async -> String? {
guard let data: Data = await get(url) else {
return nil
}
guard let string = String(data: data, encoding: .utf8) else {
print("[ERROR] Invalid string response \(data)")
return nil
}
return string
}
}

View File

@ -0,0 +1,40 @@
import Foundation
struct Configuration: Codable {
let contentFolder: String
let trainingIterations: Int
let serverPath: String
let authenticationToken: String
init(at url: URL) throws {
guard FileManager.default.fileExists(atPath: url.path) else {
print("[ERROR] No configuration at \(url.absoluteURL.path)")
throw TrainingError.configurationFileMissing
}
let data: Data
do {
data = try Data(contentsOf: url)
} catch {
print("[ERROR] Failed to load configuration data at \(url.absoluteURL.path): \(error)")
throw TrainingError.configurationFileUnreadable
}
do {
self = try JSONDecoder().decode(Configuration.self, from: data)
} catch {
print("[ERROR] Failed to decode configuration at \(url.absoluteURL.path): \(error)")
throw TrainingError.configurationFileDecodingFailed
}
}
func serverUrl() throws -> URL {
guard let serverUrl = URL(string: serverPath) else {
print("[ERROR] Configuration: Invalid server path \(serverPath)")
throw TrainingError.invalidServerPath
}
return serverUrl
}
}

View File

@ -0,0 +1,27 @@
import Foundation
enum TrainingError: Error {
case configurationFileMissing
case configurationFileUnreadable
case configurationFileDecodingFailed
case invalidServerPath
case mainImageFolderNotCreated
case failedToLoadImages
case failedToGetListOfImageFolders
case failedToRemoveImageFolder
case failedToGetClassifierVersion
case invalidClassifierVersion(String)
case failedToCreateClassifier(Error)
case failedToWriteClassifier(Error)
case failedToGetChangedImagesList
case failedToReadClassifierData(Error)
case failedToUploadClassifier
case failedToUploadClassifierClasses
}