Restructure training script, use new API
This commit is contained in:
parent
631b93872e
commit
56b74117ca
12
Training/config.sh
Normal file
12
Training/config.sh
Normal file
@ -0,0 +1,12 @@
|
||||
WORK_DIR="."
|
||||
BACKUP_DIR="./backup"
|
||||
IMAGE_DIR="../Public/images"
|
||||
MODEL_FILE="../Public/classifier.mlmodel"
|
||||
|
||||
TRAINING_ITERATIONS="20"
|
||||
SSH_PORT="22"
|
||||
SERVER="pi@mydomain.com"
|
||||
SERVER_ROOT_PATH="/caps/Public"
|
||||
SERVER_PATH="https://mydomain.com/caps"
|
||||
SERVER_AUTH="mysecretkey"
|
||||
|
@ -25,20 +25,7 @@
|
||||
#
|
||||
########################################################################################
|
||||
|
||||
########### PATHS ######################################################################
|
||||
|
||||
WORK_DIR="${HOME}/Projects/Caps/Caps-Server/Training"
|
||||
BACKUP_DIR="./backup"
|
||||
IMAGE_DIR="../Public/images"
|
||||
VERSION_FILE="../Public/classifier.version"
|
||||
MODEL_FILE="../Public/classifier.mlmodel"
|
||||
|
||||
TRAINING_ITERATIONS="17"
|
||||
SSH_PORT="5432"
|
||||
SERVER="pi@christophhagen.de"
|
||||
SERVER_ROOT_PATH="/data/servers/caps/Public"
|
||||
|
||||
########################################################################################
|
||||
source config.sh
|
||||
|
||||
echo "[INFO] Working in directory ${WORK_DIR}"
|
||||
cd $WORK_DIR
|
||||
@ -67,64 +54,4 @@ if [ $retVal -ne 0 ]; then
|
||||
return $retVal
|
||||
fi
|
||||
|
||||
echo "[INFO] Getting classifier version from server..."
|
||||
scp -P $SSH_PORT ${SERVER}:/${SERVER_ROOT_PATH}/classifier.version $VERSION_FILE
|
||||
|
||||
retVal=$?
|
||||
if [ $retVal -ne 0 ]; then
|
||||
echo '[ERROR] Failed to get classifier version'
|
||||
return $retVal
|
||||
fi
|
||||
|
||||
# Read classifier version from file
|
||||
OLD_VERSION=$(< $VERSION_FILE)
|
||||
NEW_VERSION=$(($OLD_VERSION + 1))
|
||||
|
||||
echo "[INFO] Backing up model ${OLD_VERSION}..."
|
||||
mv $MODEL_FILE "${BACKUP_DIR}/classifier${OLD_VERSION}.mlmodel"
|
||||
|
||||
retVal=$?
|
||||
if [ $retVal -ne 0 ]; then
|
||||
echo '[WARNING] Failed to back up old model'
|
||||
fi
|
||||
|
||||
echo "[INFO] Training model ${NEW_VERSION} ..."
|
||||
swift train.swift $IMAGE_DIR $TRAINING_ITERATIONS $MODEL_FILE
|
||||
|
||||
retVal=$?
|
||||
if [ $retVal -ne 0 ]; then
|
||||
echo '[ERROR] Failed to train model'
|
||||
return $retVal
|
||||
fi
|
||||
|
||||
echo "[INFO] Incrementing version file..."
|
||||
echo "${NEW_VERSION}" > $VERSION_FILE
|
||||
|
||||
echo "[INFO] Copying the files to the server..."
|
||||
scp -P $SSH_PORT $MODEL_FILE $VERSION_FILE ${SERVER}:~/
|
||||
|
||||
retVal=$?
|
||||
if [ $retVal -ne 0 ]; then
|
||||
echo '[ERROR] Failed to copy new files to server'
|
||||
return $retVal
|
||||
fi
|
||||
|
||||
echo "[INFO] Moving server files into public directory..."
|
||||
ssh -p ${SSH_PORT} ${SERVER} "sudo mv /home/pi/classifier.* ${SERVER_ROOT_PATH}"
|
||||
|
||||
retVal=$?
|
||||
if [ $retVal -ne 0 ]; then
|
||||
echo '[ERROR] Failed to move files on server'
|
||||
return $retVal
|
||||
fi
|
||||
|
||||
echo "[INFO] Updating server permissions..."
|
||||
ssh -p ${SSH_PORT} ${SERVER} "sudo chown -R www-data\:www-data ${SERVER_ROOT_PATH}"
|
||||
|
||||
retVal=$?
|
||||
if [ $retVal -ne 0 ]; then
|
||||
echo '[ERROR] Failed to update file permissions on server'
|
||||
return $retVal
|
||||
fi
|
||||
|
||||
echo "[INFO] Process finished"
|
||||
swift train.swift $SERVER_PATH $SERVER_AUTH $IMAGE_DIR $MODEL_FILE $TRAINING_ITERATIONS
|
||||
|
@ -1,61 +1,157 @@
|
||||
import Foundation
|
||||
import Cocoa
|
||||
import CreateML
|
||||
|
||||
let defaultIterations = 17
|
||||
let defaultWorkingDirectory = URL(fileURLWithPath: "/Users/imac/Development/CapCollectorData")
|
||||
let defaultTrainDirectory = defaultWorkingDirectory.appendingPathComponent("images")
|
||||
let defaultClassifierFile = defaultWorkingDirectory.appendingPathComponent("classifier.mlmodel")
|
||||
final class Server {
|
||||
|
||||
func readArguments() -> (images: URL, classifier: URL, iterations: Int) {
|
||||
let count = CommandLine.argc
|
||||
guard count > 1 else {
|
||||
// No arguments
|
||||
return (defaultTrainDirectory, defaultClassifierFile, defaultIterations)
|
||||
let server: URL
|
||||
|
||||
let authentication: String
|
||||
|
||||
init(server: URL, authentication: String) {
|
||||
self.server = server
|
||||
self.authentication = authentication
|
||||
}
|
||||
// First argument is the image directory
|
||||
let imageDir = URL(fileURLWithPath: CommandLine.arguments[1])
|
||||
let classifier = imageDir.deletingLastPathComponent().appendingPathComponent("classifier.mlmodel")
|
||||
guard count > 2 else {
|
||||
// Single argument is the image directory
|
||||
return (imageDir, classifier, defaultIterations)
|
||||
|
||||
private func wait(for request: URLRequest) -> Data? {
|
||||
let group = DispatchGroup()
|
||||
group.enter()
|
||||
var result: Data? = nil
|
||||
URLSession.shared.dataTask(with: request) { data, response, _ in
|
||||
defer { group.leave() }
|
||||
let code = (response as! HTTPURLResponse).statusCode
|
||||
guard code == 200 else {
|
||||
print("[ERROR] Invalid response \(code)")
|
||||
return
|
||||
}
|
||||
// Second argument is the iteration count
|
||||
guard let iterations = Int(CommandLine.arguments[2]) else {
|
||||
print("[ERROR] Invalid iterations argument '\(CommandLine.arguments[2])'")
|
||||
exit(-1)
|
||||
guard let data = data else {
|
||||
print("[ERROR] No response data")
|
||||
return
|
||||
}
|
||||
guard count > 3 else {
|
||||
return (imageDir, classifier, iterations)
|
||||
result = data
|
||||
}.resume()
|
||||
group.wait()
|
||||
return result
|
||||
}
|
||||
// Third argument is the classifier path
|
||||
let classifierPath = URL(fileURLWithPath: CommandLine.arguments[3])
|
||||
if count > 4 {
|
||||
print("[WARNING] Ignoring additional arguments")
|
||||
|
||||
func getClassifierVersion() -> Int? {
|
||||
let group = DispatchGroup()
|
||||
group.enter()
|
||||
let classifierVersionUrl = server.appendingPathComponent("version")
|
||||
guard let data = wait(for: URLRequest(url: classifierVersionUrl)) else {
|
||||
return nil
|
||||
}
|
||||
guard let string = String(data: data, encoding: .utf8) else {
|
||||
print("[ERROR] Invalid classifier version \(data)")
|
||||
return nil
|
||||
}
|
||||
guard let int = Int(string) else {
|
||||
print("[ERROR] Invalid classifier version \(string)")
|
||||
return nil
|
||||
}
|
||||
return int
|
||||
}
|
||||
|
||||
func upload(classifier: Data, version: Int) -> Bool {
|
||||
let classifierUrl = server
|
||||
.appendingPathComponent("classifier")
|
||||
.appendingPathComponent("\(version)")
|
||||
var request = URLRequest(url: classifierUrl)
|
||||
request.httpMethod = "POST"
|
||||
request.httpBody = classifier
|
||||
request.addValue(authentication, forHTTPHeaderField: "key")
|
||||
return wait(for: request) != nil
|
||||
}
|
||||
|
||||
func upload(classes: [Int]) -> Bool {
|
||||
let classifierUrl = server
|
||||
.appendingPathComponent("classes")
|
||||
var request = URLRequest(url: classifierUrl)
|
||||
request.httpMethod = "POST"
|
||||
request.httpBody = classes.map(String.init).joined(separator: "\n").data(using: .utf8)!
|
||||
request.addValue(authentication, forHTTPHeaderField: "key")
|
||||
return wait(for: request) != nil
|
||||
}
|
||||
return (imageDir, classifierPath, iterations)
|
||||
}
|
||||
|
||||
let arguments = readArguments()
|
||||
|
||||
print("[INFO] Using images in \(arguments.images.path)")
|
||||
print("[INFO] Training for \(arguments.iterations) iterations")
|
||||
print("[INFO] Classifier path set to \(arguments.classifier.path)")
|
||||
let count = CommandLine.argc
|
||||
guard count == 6 else {
|
||||
print("[ERROR] Invalid number of arguments")
|
||||
exit(1)
|
||||
}
|
||||
let serverPath = CommandLine.arguments[1]
|
||||
let authenticationKey = CommandLine.arguments[2]
|
||||
let imageDirectory = URL(fileURLWithPath: CommandLine.arguments[3])
|
||||
let classifierUrl = URL(fileURLWithPath: CommandLine.arguments[4])
|
||||
let iterationsString = CommandLine.arguments[5]
|
||||
|
||||
guard let serverUrl = URL(string: serverPath) else {
|
||||
print("[ERROR] Invalid server path argument")
|
||||
exit(1)
|
||||
}
|
||||
guard let iterations = Int(iterationsString) else {
|
||||
print("[ERROR] Invalid iterations argument")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
let server = Server(server: serverUrl, authentication: authenticationKey)
|
||||
|
||||
let classes: [Int]
|
||||
|
||||
do {
|
||||
classes = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
|
||||
.compactMap(Int.init)
|
||||
} catch {
|
||||
print("[ERROR] Failed to get model classes: \(error)")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
guard let oldVersion = server.getClassifierVersion() else {
|
||||
print("[ERROR] Failed to get classifier version")
|
||||
exit(1)
|
||||
}
|
||||
let newVersion = oldVersion + 1
|
||||
|
||||
print("[INFO] Image directory: \(imageDirectory.path)")
|
||||
print("[INFO] Model path: \(classifierUrl.path)")
|
||||
print("[INFO] Version: \(newVersion)")
|
||||
print("[INFO] Classes: \(classes.count)")
|
||||
print("[INFO] Iterations: \(iterations)")
|
||||
|
||||
var params = MLImageClassifier.ModelParameters(augmentation: [])
|
||||
params.maxIterations = arguments.iterations
|
||||
params.maxIterations = iterations
|
||||
|
||||
let model = try MLImageClassifier(
|
||||
trainingData: .labeledDirectories(at: arguments.images),
|
||||
let model: MLImageClassifier
|
||||
do {
|
||||
model = try MLImageClassifier(
|
||||
trainingData: .labeledDirectories(at: imageDirectory),
|
||||
parameters: params)
|
||||
} catch {
|
||||
print("[ERROR] Failed to create classifier: \(error)")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
print("[INFO] Writing classifier...")
|
||||
try model.write(to: arguments.classifier)
|
||||
print("[INFO] Saving classifier...")
|
||||
do {
|
||||
try model.write(to: classifierUrl)
|
||||
} catch {
|
||||
print("[ERROR] Failed to save model to file: \(error)")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
/*
|
||||
let evaluation = model.evaluation(on: .labeledDirectories(at: trainDirectory))
|
||||
print("Printing evaluation:")
|
||||
print(evaluation)
|
||||
print("Finished evaluation")
|
||||
*/
|
||||
print("[INFO] Uploading classifier...")
|
||||
let modelData = try Data(contentsOf: classifierUrl)
|
||||
guard server.upload(classifier: modelData, version: newVersion) else {
|
||||
print("[ERROR] Failed to upload classifier")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
print("[INFO] Uploading trained classes...")
|
||||
guard server.upload(classes: classes) else {
|
||||
print("[ERROR] Failed to upload classes")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
print("[INFO] Done")
|
||||
exit(0)
|
||||
|
Loading…
Reference in New Issue
Block a user