Switch to pure Swift training script

This commit is contained in:
Christoph Hagen 2023-01-15 11:21:47 +01:00
parent 299dba0da4
commit d516c1acd6
7 changed files with 403 additions and 200 deletions

2
.gitignore vendored
View File

@ -13,5 +13,5 @@ Public/classifier.version
Public/classifier.mlmodel Public/classifier.mlmodel
Public/caps.json Public/caps.json
Training/backup/ Training/backup/
Training/config.sh Training/config.json
Public/thumbnails Public/thumbnails

View File

@ -22,6 +22,8 @@ final class CapServer: ServerOwner {
private let fm = FileManager.default private let fm = FileManager.default
private let changedImageEntryDateFormatter: DateFormatter
// MARK: Caps // MARK: Caps
private var writers: Set<String> private var writers: Set<String>
@ -77,6 +79,8 @@ final class CapServer: ServerOwner {
self.classifierFile = folder.appendingPathComponent("classifier.mlmodel") self.classifierFile = folder.appendingPathComponent("classifier.mlmodel")
self.changedImagesFile = folder.appendingPathComponent("changes.txt") self.changedImagesFile = folder.appendingPathComponent("changes.txt")
self.writers = Set(writers) self.writers = Set(writers)
self.changedImageEntryDateFormatter = DateFormatter()
changedImageEntryDateFormatter.dateFormat = "yy-MM-dd-HH-mm-ss"
} }
func loadData() throws { func loadData() throws {
@ -303,9 +307,7 @@ final class CapServer: ServerOwner {
unwrittenImageChanges = entries unwrittenImageChanges = entries
try? handle.close() try? handle.close()
} }
let df = DateFormatter() let dateString = changedImageEntryDateFormatter.string(from: Date())
df.dateFormat = "yy-MM-dd-HH-mm-ss"
let dateString = df.string(from: Date())
while let entry = entries.popLast() { while let entry = entries.popLast() {
let content = "\(dateString):\(entry.cap):\(entry.image)\n".data(using: .utf8)! let content = "\(dateString):\(entry.cap):\(entry.image)\n".data(using: .utf8)!
try handle.write(contentsOf: content) try handle.write(contentsOf: content)
@ -333,11 +335,25 @@ final class CapServer: ServerOwner {
} }
} }
func emptyChangedImageListFile() { func removeAllEntriesInImageChangeList(before date: Date) {
do { do {
try Data().write(to: changedImagesFile) try String(contentsOf: changedImagesFile)
.components(separatedBy: "\n")
.filter { $0 != "" }
.compactMap { line -> String? in
guard let entryDate = changedImageEntryDateFormatter.date(from: line.components(separatedBy: ":").first!) else {
return nil
}
guard entryDate > date else {
return nil
}
return line
}
.joined(separator: "\n")
.data(using: .utf8)!
.write(to: changedImagesFile)
} catch { } catch {
log("Failed to empty changed images file: \(error)") log("Failed to update changed images file: \(error)")
} }
} }

View File

@ -4,6 +4,13 @@ import Foundation
/// The decoder to extract caps from JSON payloads given to the `cap` route. /// The decoder to extract caps from JSON payloads given to the `cap` route.
private let decoder = JSONDecoder() private let decoder = JSONDecoder()
/// The date formatter to decode dates in requests
private let dateFormatter: DateFormatter = {
let df = DateFormatter()
df.dateFormat = "yy-MM-dd-HH-mm-ss"
return df
}()
private func authorize(_ request: Request) throws { private func authorize(_ request: Request) throws {
guard let key = request.headers.first(name: "key") else { guard let key = request.headers.first(name: "key") else {
throw Abort(.badRequest) // 400 throw Abort(.badRequest) // 400
@ -65,7 +72,16 @@ func routes(_ app: Application) {
} }
// Update the trained classes // Update the trained classes
app.postCatching("classes") { request in app.postCatching("classes", ":date") { request in
guard let dateString = request.parameters.get("date") else {
log("Invalid parameter for date")
throw Abort(.badRequest)
}
guard let date = dateFormatter.date(from: dateString) else {
log("Invalid date specification")
throw Abort(.badRequest)
}
try authorize(request) try authorize(request)
guard let buffer = request.body.data else { guard let buffer = request.body.data else {
log("Missing body data: \(request.body.description)") log("Missing body data: \(request.body.description)")
@ -77,6 +93,6 @@ func routes(_ app: Application) {
throw CapError.invalidBody throw CapError.invalidBody
} }
server.updateTrainedClasses(content: content) server.updateTrainedClasses(content: content)
server.emptyChangedImageListFile() server.removeAllEntriesInImageChangeList(before: date)
} }
} }

View File

@ -0,0 +1,7 @@
{
"imageDirectory": "../Public/images",
"classifierModelPath": "../Public/classifier.mlmodel",
"trainingIterations": 20,
"serverPath": "https://mydomain.com/caps",
"authenticationToken": "mysecretkey",
}

View File

@ -1,10 +0,0 @@
IMAGE_DIR="../Public/images"
MODEL_FILE="../Public/classifier.mlmodel"
TRAINING_ITERATIONS="20"
SSH_PORT="22"
SERVER="pi@mydomain.com"
SERVER_ROOT_PATH="/caps/Public"
SERVER_PATH="https://mydomain.com/caps"
SERVER_AUTH="mysecretkey"

View File

@ -1,48 +0,0 @@
#!/bin/bash
########### SUDO COMMANDS WITHOUT PASSWORD #############################################
#
# Specific commands must be executable without sudo
#
# Add to sudoers file (sudo visudo):
#
# Disable password for specific commands
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/chmod -R 755 /data/public/capserver/images
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/mv /home/pi/classifier.mlmodel /data/public/capserver/
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/mv /home/pi/classifier.version /data/public/capserver/
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/chown -R www-data\:www-data /data/public/capserver/
#
########################################################################################
########### EXPLANATION OF RSYNC FLAGS #################################################
#
# -h human readable output
# -v verbose output
# -r recursive
# -P print information about long-running transfers, keep partial files (—-progress --partial)
# -t preserves modification times
# -u only update if newer
#
########################################################################################
source config.sh
echo "[INFO] Ensuring permissions for images on server..."
ssh -p $SSH_PORT ${SERVER} "sudo chmod -R 755 ${SERVER_ROOT_PATH}/images"
retVal=$?
if [ $retVal -ne 0 ]; then
echo '[ERROR] Failed to change image permissions'
return $retVal
fi
echo "[INFO] Transferring images from server..."
rsync -hrut --info=progress2 -e "ssh -p ${SSH_PORT}" ${SERVER}:/${SERVER_ROOT_PATH}/images/ "${IMAGE_DIR}"
retVal=$?
if [ $retVal -ne 0 ]; then
echo '[ERROR] Failed to transfer images from server'
return $retVal
fi
swift train.swift $SERVER_PATH $SERVER_AUTH $IMAGE_DIR $MODEL_FILE $TRAINING_ITERATIONS

View File

@ -2,126 +2,289 @@ import Foundation
import Cocoa import Cocoa
import CreateML import CreateML
final class Server { struct Configuration: Codable {
let imageDirectory: String
let classifierModelPath: String
let trainingIterations: Int
let serverPath: String
let authenticationToken: String
init?(at url: URL) {
do {
let configData = try Data(contentsOf: url)
self = try JSONDecoder().decode(Configuration.self, from: configData)
} catch {
print("[ERROR] Failed to load configuration at \(url.absoluteURL.path): \(error)")
return nil
}
}
}
struct Cap: Codable {
let id: Int
let count: Int
enum CodingKeys: String, CodingKey {
case id = "i"
case count = "c"
}
}
final class ClassifierCreator {
static let configurationFileUrl = URL(fileURLWithPath: "config.json")
let server: URL let server: URL
let authentication: String let configuration: Configuration
init(server: URL, authentication: String) { let imageDirectory: URL
self.server = server
self.authentication = authentication
}
private func wait(for request: URLRequest) -> Data? { let classifierUrl: URL
let group = DispatchGroup()
group.enter()
var result: Data? = nil
URLSession.shared.dataTask(with: request) { data, response, _ in
defer { group.leave() }
let code = (response as! HTTPURLResponse).statusCode
guard code == 200 else {
print("[ERROR] Invalid response \(code)")
return
}
guard let data = data else {
print("[ERROR] No response data")
return
}
result = data
}.resume()
group.wait()
return result
}
func getClassifierVersion() -> Int? { let df = DateFormatter()
let group = DispatchGroup()
group.enter()
let classifierVersionUrl = server.appendingPathComponent("version") // MARK: Step 1: Load configuration
guard let data = wait(for: URLRequest(url: classifierVersionUrl)) else {
init?() {
guard let configuration = Configuration(at: ClassifierCreator.configurationFileUrl) else {
return nil return nil
} }
guard let string = String(data: data, encoding: .utf8) else { self.configuration = configuration
print("[ERROR] Invalid classifier version \(data)") guard let serverUrl = URL(string: configuration.serverPath) else {
print("[ERROR] Configuration: Invalid server path \(configuration.serverPath)")
return nil return nil
} }
guard let int = Int(string) else { self.server = serverUrl
self.imageDirectory = URL(fileURLWithPath: configuration.imageDirectory)
self.classifierUrl = URL(fileURLWithPath: configuration.classifierModelPath)
df.dateFormat = "yy-MM-dd-HH-mm-ss"
}
// MARK: Main function
func run() async {
let imagesSnapshotDate = Date()
guard let (classes, changedImageCount) = await loadImages() else {
return
}
guard !classes.isEmpty else {
return
}
guard changedImageCount > 0 else {
print("[INFO] No changed images, so no new classifier trained")
return
}
guard let classifierVersion = await getClassifierVersion() else {
return
}
let newVersion = classifierVersion + 1
print("[INFO] Image directory: \(imageDirectory.absoluteURL.path)")
print("[INFO] Model path: \(configuration.classifierModelPath)")
print("[INFO] Version: \(newVersion)")
print("[INFO] Classes: \(classes.count)")
print("[INFO] Iterations: \(configuration.trainingIterations)")
guard trainModel() else {
return
}
guard await upload(version: newVersion) else {
return
}
guard await upload(classes: classes, lastUpdate: imagesSnapshotDate) else {
return
}
print("[INFO] Done")
}
// MARK: Step 2: Load changed images
func loadImages() async -> (classes: [Int], changedImages: Int)? {
guard createImageFolderIfMissing() else {
return nil
}
let imageCounts = await getImageCounts()
let missingImageList: [(cap: Int, image: Int)] = imageCounts
.sorted { $0.key < $1.key }
.reduce(into: []) { list, pair in
let missingImagesForCap: [(cap: Int, image: Int)] = (0..<pair.value).compactMap { image in
let url = imageUrl(base: imageDirectory, cap: pair.key, image: image)
guard !FileManager.default.fileExists(atPath: url.path) else {
return nil
}
return (cap: pair.key, image: image)
}
list.append(contentsOf: missingImagesForCap)
}
if missingImageList.isEmpty {
print("[INFO] No missing images to load")
} else {
print("[INFO] Loading \(missingImageList.count) missing images...")
}
guard await loadImages(missingImageList) else {
return nil
}
let changedImageList = await getChangedImageList()
if changedImageList.isEmpty {
print("[INFO] No changed images to load")
} else {
print("[INFO] Loading \(changedImageList.count) changed images...")
}
guard await loadImages(changedImageList) else {
return nil
}
let classes = imageCounts.keys.sorted()
return (classes, missingImageList.count + changedImageList.count)
}
private func createImageFolderIfMissing() -> Bool {
guard !FileManager.default.fileExists(atPath: imageDirectory.path) else {
return true
}
do {
try FileManager.default.createDirectory(at: imageDirectory, withIntermediateDirectories: true)
return true
} catch {
print("[ERROR] Failed to create image directory: \(error)")
return false
}
}
private func getImageCounts() async -> [Int : Int] {
guard let data: Data = await get(server.appendingPathComponent("caps.json")) else {
return [:]
}
do {
return try JSONDecoder().decode([Cap].self, from: data)
.reduce(into: [:]) { $0[$1.id] = $1.count }
} catch {
print("[ERROR] Failed to decode cap database: \(error)")
return [:]
}
}
private func imageUrl(base: URL, cap: Int, image: Int) -> URL {
base.appendingPathComponent(String(format: "%04d/%04d-%02d.jpg", cap, cap, image))
}
private func loadImage(cap: Int, image: Int) async -> Bool {
let url = imageUrl(base: server.appendingPathComponent("images"), cap: cap, image: image)
let tempFile: URL, response: URLResponse
do {
(tempFile, response) = try await URLSession.shared.download(from: url)
} catch {
print("[ERROR] Failed to load image \(image) of cap \(cap): \(error)")
return false
}
let responseCode = (response as! HTTPURLResponse).statusCode
guard responseCode == 200 else {
print("[ERROR] Failed to load image \(image) of cap \(cap): Response \(responseCode)")
return false
}
do {
let localUrl = imageUrl(base: imageDirectory, cap: cap, image: image)
if FileManager.default.fileExists(atPath: localUrl.path) {
try FileManager.default.removeItem(at: localUrl)
}
try FileManager.default.moveItem(at: tempFile, to: localUrl)
return true
} catch {
print("[ERROR] Failed to save image \(image) of cap \(cap): \(error)")
return false
}
}
private func loadImages(_ list: [(cap: Int, image: Int)]) async -> Bool {
guard !list.isEmpty else {
return true
}
var loadedImages = 0
await withTaskGroup(of: Bool.self) { group in
for (cap, image) in list {
group.addTask {
await self.loadImage(cap: cap, image: image)
}
}
for await loaded in group {
if loaded {
loadedImages += 1
}
}
}
if loadedImages != list.count {
print("[ERROR] Only \(loadedImages) of \(list.count) images loaded")
return false
}
return true
}
func getChangedImageList() async -> [(cap: Int, image: Int)] {
guard let string: String = await get(server.appendingPathComponent("changes.txt")) else {
print("[ERROR] Failed to get list of changed images")
return []
}
return string
.components(separatedBy: "\n")
.filter { $0 != "" }
.compactMap {
let parts = $0.components(separatedBy: ":")
guard parts.count == 3 else {
return nil
}
/*
guard let date = df.date(from: parts[0]) else {
print("[WARN] Invalid date \(parts[0]) in change list")
return nil
}
*/
guard let cap = Int(parts[1]) else {
print("[WARN] Invalid cap id \(parts[1]) in change list")
return nil
}
guard let image = Int(parts[2]) else {
print("[WARN] Invalid image id \(parts[2]) in change list")
return nil
}
return (cap, image)
}
}
// MARK: Step 3: Compute version
func getClassifierVersion() async -> Int? {
guard let string: String = await get(server.appendingPathComponent("version")) else {
print("[ERROR] Failed to get classifier version")
return nil
}
guard let version = Int(string) else {
print("[ERROR] Invalid classifier version \(string)") print("[ERROR] Invalid classifier version \(string)")
return nil return nil
} }
return int return version
} }
func upload(classifier: Data, version: Int) -> Bool { // MARK: Step 4: Train classifier
let classifierUrl = server
.appendingPathComponent("classifier")
.appendingPathComponent("\(version)")
var request = URLRequest(url: classifierUrl)
request.httpMethod = "POST"
request.httpBody = classifier
request.addValue(authentication, forHTTPHeaderField: "key")
return wait(for: request) != nil
}
func upload(classes: [Int]) -> Bool {
let classifierUrl = server
.appendingPathComponent("classes")
var request = URLRequest(url: classifierUrl)
request.httpMethod = "POST"
request.httpBody = classes.map(String.init).joined(separator: "\n").data(using: .utf8)!
request.addValue(authentication, forHTTPHeaderField: "key")
return wait(for: request) != nil
}
}
let count = CommandLine.argc
guard count == 6 else {
print("[ERROR] Invalid number of arguments")
exit(1)
}
let serverPath = CommandLine.arguments[1]
let authenticationKey = CommandLine.arguments[2]
let imageDirectory = URL(fileURLWithPath: CommandLine.arguments[3])
let classifierUrl = URL(fileURLWithPath: CommandLine.arguments[4])
let iterationsString = CommandLine.arguments[5]
guard let serverUrl = URL(string: serverPath) else {
print("[ERROR] Invalid server path argument")
exit(1)
}
guard let iterations = Int(iterationsString) else {
print("[ERROR] Invalid iterations argument")
exit(1)
}
let server = Server(server: serverUrl, authentication: authenticationKey)
let classes: [Int]
do {
classes = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
.compactMap(Int.init)
} catch {
print("[ERROR] Failed to get model classes: \(error)")
exit(1)
}
guard let oldVersion = server.getClassifierVersion() else {
print("[ERROR] Failed to get classifier version")
exit(1)
}
let newVersion = oldVersion + 1
print("[INFO] Image directory: \(imageDirectory.path)")
print("[INFO] Model path: \(classifierUrl.path)")
print("[INFO] Version: \(newVersion)")
print("[INFO] Classes: \(classes.count)")
print("[INFO] Iterations: \(iterations)")
func trainModel() -> Bool {
var params = MLImageClassifier.ModelParameters(augmentation: []) var params = MLImageClassifier.ModelParameters(augmentation: [])
params.maxIterations = iterations params.maxIterations = configuration.trainingIterations
let model: MLImageClassifier let model: MLImageClassifier
do { do {
model = try MLImageClassifier( model = try MLImageClassifier(
@ -129,29 +292,88 @@ do {
parameters: params) parameters: params)
} catch { } catch {
print("[ERROR] Failed to create classifier: \(error)") print("[ERROR] Failed to create classifier: \(error)")
exit(1) return false
} }
print("[INFO] Saving classifier...") print("[INFO] Saving classifier...")
do { do {
try model.write(to: classifierUrl) try model.write(to: classifierUrl)
return true
} catch { } catch {
print("[ERROR] Failed to save model to file: \(error)") print("[ERROR] Failed to save model to file: \(error)")
exit(1) return false
}
} }
// MARK: Step 5: Upload classifier
func upload(version: Int) async -> Bool {
print("[INFO] Uploading classifier...") print("[INFO] Uploading classifier...")
let modelData = try Data(contentsOf: classifierUrl) let modelData: Data
guard server.upload(classifier: modelData, version: newVersion) else { do {
print("[ERROR] Failed to upload classifier") modelData = try Data(contentsOf: classifierUrl)
exit(1) } catch {
print("[ERROR] Failed to read classifier data: \(error)")
return false
} }
return await post(
url: server.appendingPathComponent("classifier/\(version)"),
body: modelData)
}
// MARK: Step 6: Update classes
func upload(classes: [Int], lastUpdate: Date) async -> Bool {
print("[INFO] Uploading trained classes...") print("[INFO] Uploading trained classes...")
guard server.upload(classes: classes) else { let dateString = df.string(from: lastUpdate)
print("[ERROR] Failed to upload classes") return await post(
exit(1) url: server.appendingPathComponent("classes/\(dateString)"),
body: classes.map(String.init).joined(separator: "\n").data(using: .utf8)!,
dateKey: lastUpdate)
} }
print("[INFO] Done") // MARK: Requests
exit(0)
private func post(url: URL, body: Data) async -> Bool {
var request = URLRequest(url: url)
request.httpMethod = "POST"
request.httpBody = body
request.addValue(configuration.authenticationToken, forHTTPHeaderField: "key")
return await perform(request) != nil
}
private func perform(_ request: URLRequest) async -> Data? {
let data: Data
let response: URLResponse
do {
(data, response) = try await URLSession.shared.data(for: request)
} catch {
print("[ERROR] Request failed: \(error)")
return nil
}
let code = (response as! HTTPURLResponse).statusCode
guard code == 200 else {
print("[ERROR] Invalid response \(code)")
return nil
}
return data
}
private func get(_ url: URL) async -> Data? {
await perform(URLRequest(url: url))
}
private func get(_ url: URL) async -> String? {
guard let data: Data = await get(url) else {
return nil
}
guard let string = String(data: data, encoding: .utf8) else {
print("[ERROR] Invalid string response \(data)")
return nil
}
return string
}
}
await ClassifierCreator().run()