Restructure training script, use new API

This commit is contained in:
Christoph Hagen 2022-06-24 12:13:37 +02:00
parent 631b93872e
commit 56b74117ca
3 changed files with 155 additions and 120 deletions

12
Training/config.sh Normal file
View File

@ -0,0 +1,12 @@
WORK_DIR="."
BACKUP_DIR="./backup"
IMAGE_DIR="../Public/images"
MODEL_FILE="../Public/classifier.mlmodel"
TRAINING_ITERATIONS="20"
SSH_PORT="22"
SERVER="pi@mydomain.com"
SERVER_ROOT_PATH="/caps/Public"
SERVER_PATH="https://mydomain.com/caps"
SERVER_AUTH="mysecretkey"

View File

@ -25,20 +25,7 @@
#
########################################################################################
########### PATHS ######################################################################
WORK_DIR="${HOME}/Projects/Caps/Caps-Server/Training"
BACKUP_DIR="./backup"
IMAGE_DIR="../Public/images"
VERSION_FILE="../Public/classifier.version"
MODEL_FILE="../Public/classifier.mlmodel"
TRAINING_ITERATIONS="17"
SSH_PORT="5432"
SERVER="pi@christophhagen.de"
SERVER_ROOT_PATH="/data/servers/caps/Public"
########################################################################################
source config.sh
echo "[INFO] Working in directory ${WORK_DIR}"
cd $WORK_DIR
@ -67,64 +54,4 @@ if [ $retVal -ne 0 ]; then
return $retVal
fi
echo "[INFO] Getting classifier version from server..."
scp -P $SSH_PORT ${SERVER}:/${SERVER_ROOT_PATH}/classifier.version $VERSION_FILE
retVal=$?
if [ $retVal -ne 0 ]; then
echo '[ERROR] Failed to get classifier version'
return $retVal
fi
# Read classifier version from file
OLD_VERSION=$(< $VERSION_FILE)
NEW_VERSION=$(($OLD_VERSION + 1))
echo "[INFO] Backing up model ${OLD_VERSION}..."
mv $MODEL_FILE "${BACKUP_DIR}/classifier${OLD_VERSION}.mlmodel"
retVal=$?
if [ $retVal -ne 0 ]; then
echo '[WARNING] Failed to back up old model'
fi
echo "[INFO] Training model ${NEW_VERSION} ..."
swift train.swift $IMAGE_DIR $TRAINING_ITERATIONS $MODEL_FILE
retVal=$?
if [ $retVal -ne 0 ]; then
echo '[ERROR] Failed to train model'
return $retVal
fi
echo "[INFO] Incrementing version file..."
echo "${NEW_VERSION}" > $VERSION_FILE
echo "[INFO] Copying the files to the server..."
scp -P $SSH_PORT $MODEL_FILE $VERSION_FILE ${SERVER}:~/
retVal=$?
if [ $retVal -ne 0 ]; then
echo '[ERROR] Failed to copy new files to server'
return $retVal
fi
echo "[INFO] Moving server files into public directory..."
ssh -p ${SSH_PORT} ${SERVER} "sudo mv /home/pi/classifier.* ${SERVER_ROOT_PATH}"
retVal=$?
if [ $retVal -ne 0 ]; then
echo '[ERROR] Failed to move files on server'
return $retVal
fi
echo "[INFO] Updating server permissions..."
ssh -p ${SSH_PORT} ${SERVER} "sudo chown -R www-data\:www-data ${SERVER_ROOT_PATH}"
retVal=$?
if [ $retVal -ne 0 ]; then
echo '[ERROR] Failed to update file permissions on server'
return $retVal
fi
echo "[INFO] Process finished"
swift train.swift $SERVER_PATH $SERVER_AUTH $IMAGE_DIR $MODEL_FILE $TRAINING_ITERATIONS

View File

@ -1,61 +1,157 @@
import Foundation
import Cocoa
import CreateML
let defaultIterations = 17
let defaultWorkingDirectory = URL(fileURLWithPath: "/Users/imac/Development/CapCollectorData")
let defaultTrainDirectory = defaultWorkingDirectory.appendingPathComponent("images")
let defaultClassifierFile = defaultWorkingDirectory.appendingPathComponent("classifier.mlmodel")
final class Server {
let server: URL
let authentication: String
init(server: URL, authentication: String) {
self.server = server
self.authentication = authentication
}
private func wait(for request: URLRequest) -> Data? {
let group = DispatchGroup()
group.enter()
var result: Data? = nil
URLSession.shared.dataTask(with: request) { data, response, _ in
defer { group.leave() }
let code = (response as! HTTPURLResponse).statusCode
guard code == 200 else {
print("[ERROR] Invalid response \(code)")
return
}
guard let data = data else {
print("[ERROR] No response data")
return
}
result = data
}.resume()
group.wait()
return result
}
func getClassifierVersion() -> Int? {
let group = DispatchGroup()
group.enter()
let classifierVersionUrl = server.appendingPathComponent("version")
guard let data = wait(for: URLRequest(url: classifierVersionUrl)) else {
return nil
}
guard let string = String(data: data, encoding: .utf8) else {
print("[ERROR] Invalid classifier version \(data)")
return nil
}
guard let int = Int(string) else {
print("[ERROR] Invalid classifier version \(string)")
return nil
}
return int
}
func readArguments() -> (images: URL, classifier: URL, iterations: Int) {
let count = CommandLine.argc
guard count > 1 else {
// No arguments
return (defaultTrainDirectory, defaultClassifierFile, defaultIterations)
func upload(classifier: Data, version: Int) -> Bool {
let classifierUrl = server
.appendingPathComponent("classifier")
.appendingPathComponent("\(version)")
var request = URLRequest(url: classifierUrl)
request.httpMethod = "POST"
request.httpBody = classifier
request.addValue(authentication, forHTTPHeaderField: "key")
return wait(for: request) != nil
}
// First argument is the image directory
let imageDir = URL(fileURLWithPath: CommandLine.arguments[1])
let classifier = imageDir.deletingLastPathComponent().appendingPathComponent("classifier.mlmodel")
guard count > 2 else {
// Single argument is the image directory
return (imageDir, classifier, defaultIterations)
func upload(classes: [Int]) -> Bool {
let classifierUrl = server
.appendingPathComponent("classes")
var request = URLRequest(url: classifierUrl)
request.httpMethod = "POST"
request.httpBody = classes.map(String.init).joined(separator: "\n").data(using: .utf8)!
request.addValue(authentication, forHTTPHeaderField: "key")
return wait(for: request) != nil
}
// Second argument is the iteration count
guard let iterations = Int(CommandLine.arguments[2]) else {
print("[ERROR] Invalid iterations argument '\(CommandLine.arguments[2])'")
exit(-1)
}
guard count > 3 else {
return (imageDir, classifier, iterations)
}
// Third argument is the classifier path
let classifierPath = URL(fileURLWithPath: CommandLine.arguments[3])
if count > 4 {
print("[WARNING] Ignoring additional arguments")
}
return (imageDir, classifierPath, iterations)
}
let arguments = readArguments()
print("[INFO] Using images in \(arguments.images.path)")
print("[INFO] Training for \(arguments.iterations) iterations")
print("[INFO] Classifier path set to \(arguments.classifier.path)")
let count = CommandLine.argc
guard count == 6 else {
print("[ERROR] Invalid number of arguments")
exit(1)
}
let serverPath = CommandLine.arguments[1]
let authenticationKey = CommandLine.arguments[2]
let imageDirectory = URL(fileURLWithPath: CommandLine.arguments[3])
let classifierUrl = URL(fileURLWithPath: CommandLine.arguments[4])
let iterationsString = CommandLine.arguments[5]
guard let serverUrl = URL(string: serverPath) else {
print("[ERROR] Invalid server path argument")
exit(1)
}
guard let iterations = Int(iterationsString) else {
print("[ERROR] Invalid iterations argument")
exit(1)
}
let server = Server(server: serverUrl, authentication: authenticationKey)
let classes: [Int]
do {
classes = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
.compactMap(Int.init)
} catch {
print("[ERROR] Failed to get model classes: \(error)")
exit(1)
}
guard let oldVersion = server.getClassifierVersion() else {
print("[ERROR] Failed to get classifier version")
exit(1)
}
let newVersion = oldVersion + 1
print("[INFO] Image directory: \(imageDirectory.path)")
print("[INFO] Model path: \(classifierUrl.path)")
print("[INFO] Version: \(newVersion)")
print("[INFO] Classes: \(classes.count)")
print("[INFO] Iterations: \(iterations)")
var params = MLImageClassifier.ModelParameters(augmentation: [])
params.maxIterations = arguments.iterations
params.maxIterations = iterations
let model = try MLImageClassifier(
trainingData: .labeledDirectories(at: arguments.images),
parameters: params)
let model: MLImageClassifier
do {
model = try MLImageClassifier(
trainingData: .labeledDirectories(at: imageDirectory),
parameters: params)
} catch {
print("[ERROR] Failed to create classifier: \(error)")
exit(1)
}
print("[INFO] Writing classifier...")
try model.write(to: arguments.classifier)
print("[INFO] Saving classifier...")
do {
try model.write(to: classifierUrl)
} catch {
print("[ERROR] Failed to save model to file: \(error)")
exit(1)
}
/*
let evaluation = model.evaluation(on: .labeledDirectories(at: trainDirectory))
print("Printing evaluation:")
print(evaluation)
print("Finished evaluation")
*/
print("[INFO] Uploading classifier...")
let modelData = try Data(contentsOf: classifierUrl)
guard server.upload(classifier: modelData, version: newVersion) else {
print("[ERROR] Failed to upload classifier")
exit(1)
}
print("[INFO] Uploading trained classes...")
guard server.upload(classes: classes) else {
print("[ERROR] Failed to upload classes")
exit(1)
}
print("[INFO] Done")
exit(0)