Restructure training script, use new API
This commit is contained in:
parent
631b93872e
commit
56b74117ca
12
Training/config.sh
Normal file
12
Training/config.sh
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
WORK_DIR="."
|
||||||
|
BACKUP_DIR="./backup"
|
||||||
|
IMAGE_DIR="../Public/images"
|
||||||
|
MODEL_FILE="../Public/classifier.mlmodel"
|
||||||
|
|
||||||
|
TRAINING_ITERATIONS="20"
|
||||||
|
SSH_PORT="22"
|
||||||
|
SERVER="pi@mydomain.com"
|
||||||
|
SERVER_ROOT_PATH="/caps/Public"
|
||||||
|
SERVER_PATH="https://mydomain.com/caps"
|
||||||
|
SERVER_AUTH="mysecretkey"
|
||||||
|
|
@ -25,20 +25,7 @@
|
|||||||
#
|
#
|
||||||
########################################################################################
|
########################################################################################
|
||||||
|
|
||||||
########### PATHS ######################################################################
|
source config.sh
|
||||||
|
|
||||||
WORK_DIR="${HOME}/Projects/Caps/Caps-Server/Training"
|
|
||||||
BACKUP_DIR="./backup"
|
|
||||||
IMAGE_DIR="../Public/images"
|
|
||||||
VERSION_FILE="../Public/classifier.version"
|
|
||||||
MODEL_FILE="../Public/classifier.mlmodel"
|
|
||||||
|
|
||||||
TRAINING_ITERATIONS="17"
|
|
||||||
SSH_PORT="5432"
|
|
||||||
SERVER="pi@christophhagen.de"
|
|
||||||
SERVER_ROOT_PATH="/data/servers/caps/Public"
|
|
||||||
|
|
||||||
########################################################################################
|
|
||||||
|
|
||||||
echo "[INFO] Working in directory ${WORK_DIR}"
|
echo "[INFO] Working in directory ${WORK_DIR}"
|
||||||
cd $WORK_DIR
|
cd $WORK_DIR
|
||||||
@ -67,64 +54,4 @@ if [ $retVal -ne 0 ]; then
|
|||||||
return $retVal
|
return $retVal
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "[INFO] Getting classifier version from server..."
|
swift train.swift $SERVER_PATH $SERVER_AUTH $IMAGE_DIR $MODEL_FILE $TRAINING_ITERATIONS
|
||||||
scp -P $SSH_PORT ${SERVER}:/${SERVER_ROOT_PATH}/classifier.version $VERSION_FILE
|
|
||||||
|
|
||||||
retVal=$?
|
|
||||||
if [ $retVal -ne 0 ]; then
|
|
||||||
echo '[ERROR] Failed to get classifier version'
|
|
||||||
return $retVal
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Read classifier version from file
|
|
||||||
OLD_VERSION=$(< $VERSION_FILE)
|
|
||||||
NEW_VERSION=$(($OLD_VERSION + 1))
|
|
||||||
|
|
||||||
echo "[INFO] Backing up model ${OLD_VERSION}..."
|
|
||||||
mv $MODEL_FILE "${BACKUP_DIR}/classifier${OLD_VERSION}.mlmodel"
|
|
||||||
|
|
||||||
retVal=$?
|
|
||||||
if [ $retVal -ne 0 ]; then
|
|
||||||
echo '[WARNING] Failed to back up old model'
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "[INFO] Training model ${NEW_VERSION} ..."
|
|
||||||
swift train.swift $IMAGE_DIR $TRAINING_ITERATIONS $MODEL_FILE
|
|
||||||
|
|
||||||
retVal=$?
|
|
||||||
if [ $retVal -ne 0 ]; then
|
|
||||||
echo '[ERROR] Failed to train model'
|
|
||||||
return $retVal
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "[INFO] Incrementing version file..."
|
|
||||||
echo "${NEW_VERSION}" > $VERSION_FILE
|
|
||||||
|
|
||||||
echo "[INFO] Copying the files to the server..."
|
|
||||||
scp -P $SSH_PORT $MODEL_FILE $VERSION_FILE ${SERVER}:~/
|
|
||||||
|
|
||||||
retVal=$?
|
|
||||||
if [ $retVal -ne 0 ]; then
|
|
||||||
echo '[ERROR] Failed to copy new files to server'
|
|
||||||
return $retVal
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "[INFO] Moving server files into public directory..."
|
|
||||||
ssh -p ${SSH_PORT} ${SERVER} "sudo mv /home/pi/classifier.* ${SERVER_ROOT_PATH}"
|
|
||||||
|
|
||||||
retVal=$?
|
|
||||||
if [ $retVal -ne 0 ]; then
|
|
||||||
echo '[ERROR] Failed to move files on server'
|
|
||||||
return $retVal
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "[INFO] Updating server permissions..."
|
|
||||||
ssh -p ${SSH_PORT} ${SERVER} "sudo chown -R www-data\:www-data ${SERVER_ROOT_PATH}"
|
|
||||||
|
|
||||||
retVal=$?
|
|
||||||
if [ $retVal -ne 0 ]; then
|
|
||||||
echo '[ERROR] Failed to update file permissions on server'
|
|
||||||
return $retVal
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "[INFO] Process finished"
|
|
||||||
|
@ -1,61 +1,157 @@
|
|||||||
|
import Foundation
|
||||||
import Cocoa
|
import Cocoa
|
||||||
import CreateML
|
import CreateML
|
||||||
|
|
||||||
let defaultIterations = 17
|
final class Server {
|
||||||
let defaultWorkingDirectory = URL(fileURLWithPath: "/Users/imac/Development/CapCollectorData")
|
|
||||||
let defaultTrainDirectory = defaultWorkingDirectory.appendingPathComponent("images")
|
let server: URL
|
||||||
let defaultClassifierFile = defaultWorkingDirectory.appendingPathComponent("classifier.mlmodel")
|
|
||||||
|
let authentication: String
|
||||||
|
|
||||||
|
init(server: URL, authentication: String) {
|
||||||
|
self.server = server
|
||||||
|
self.authentication = authentication
|
||||||
|
}
|
||||||
|
|
||||||
|
private func wait(for request: URLRequest) -> Data? {
|
||||||
|
let group = DispatchGroup()
|
||||||
|
group.enter()
|
||||||
|
var result: Data? = nil
|
||||||
|
URLSession.shared.dataTask(with: request) { data, response, _ in
|
||||||
|
defer { group.leave() }
|
||||||
|
let code = (response as! HTTPURLResponse).statusCode
|
||||||
|
guard code == 200 else {
|
||||||
|
print("[ERROR] Invalid response \(code)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
guard let data = data else {
|
||||||
|
print("[ERROR] No response data")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result = data
|
||||||
|
}.resume()
|
||||||
|
group.wait()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func getClassifierVersion() -> Int? {
|
||||||
|
let group = DispatchGroup()
|
||||||
|
group.enter()
|
||||||
|
let classifierVersionUrl = server.appendingPathComponent("version")
|
||||||
|
guard let data = wait(for: URLRequest(url: classifierVersionUrl)) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
guard let string = String(data: data, encoding: .utf8) else {
|
||||||
|
print("[ERROR] Invalid classifier version \(data)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
guard let int = Int(string) else {
|
||||||
|
print("[ERROR] Invalid classifier version \(string)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return int
|
||||||
|
}
|
||||||
|
|
||||||
func readArguments() -> (images: URL, classifier: URL, iterations: Int) {
|
func upload(classifier: Data, version: Int) -> Bool {
|
||||||
let count = CommandLine.argc
|
let classifierUrl = server
|
||||||
guard count > 1 else {
|
.appendingPathComponent("classifier")
|
||||||
// No arguments
|
.appendingPathComponent("\(version)")
|
||||||
return (defaultTrainDirectory, defaultClassifierFile, defaultIterations)
|
var request = URLRequest(url: classifierUrl)
|
||||||
|
request.httpMethod = "POST"
|
||||||
|
request.httpBody = classifier
|
||||||
|
request.addValue(authentication, forHTTPHeaderField: "key")
|
||||||
|
return wait(for: request) != nil
|
||||||
}
|
}
|
||||||
// First argument is the image directory
|
|
||||||
let imageDir = URL(fileURLWithPath: CommandLine.arguments[1])
|
func upload(classes: [Int]) -> Bool {
|
||||||
let classifier = imageDir.deletingLastPathComponent().appendingPathComponent("classifier.mlmodel")
|
let classifierUrl = server
|
||||||
guard count > 2 else {
|
.appendingPathComponent("classes")
|
||||||
// Single argument is the image directory
|
var request = URLRequest(url: classifierUrl)
|
||||||
return (imageDir, classifier, defaultIterations)
|
request.httpMethod = "POST"
|
||||||
|
request.httpBody = classes.map(String.init).joined(separator: "\n").data(using: .utf8)!
|
||||||
|
request.addValue(authentication, forHTTPHeaderField: "key")
|
||||||
|
return wait(for: request) != nil
|
||||||
}
|
}
|
||||||
// Second argument is the iteration count
|
|
||||||
guard let iterations = Int(CommandLine.arguments[2]) else {
|
|
||||||
print("[ERROR] Invalid iterations argument '\(CommandLine.arguments[2])'")
|
|
||||||
exit(-1)
|
|
||||||
}
|
|
||||||
guard count > 3 else {
|
|
||||||
return (imageDir, classifier, iterations)
|
|
||||||
}
|
|
||||||
// Third argument is the classifier path
|
|
||||||
let classifierPath = URL(fileURLWithPath: CommandLine.arguments[3])
|
|
||||||
if count > 4 {
|
|
||||||
print("[WARNING] Ignoring additional arguments")
|
|
||||||
}
|
|
||||||
return (imageDir, classifierPath, iterations)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let arguments = readArguments()
|
|
||||||
|
|
||||||
print("[INFO] Using images in \(arguments.images.path)")
|
let count = CommandLine.argc
|
||||||
print("[INFO] Training for \(arguments.iterations) iterations")
|
guard count == 6 else {
|
||||||
print("[INFO] Classifier path set to \(arguments.classifier.path)")
|
print("[ERROR] Invalid number of arguments")
|
||||||
|
exit(1)
|
||||||
|
}
|
||||||
|
let serverPath = CommandLine.arguments[1]
|
||||||
|
let authenticationKey = CommandLine.arguments[2]
|
||||||
|
let imageDirectory = URL(fileURLWithPath: CommandLine.arguments[3])
|
||||||
|
let classifierUrl = URL(fileURLWithPath: CommandLine.arguments[4])
|
||||||
|
let iterationsString = CommandLine.arguments[5]
|
||||||
|
|
||||||
|
guard let serverUrl = URL(string: serverPath) else {
|
||||||
|
print("[ERROR] Invalid server path argument")
|
||||||
|
exit(1)
|
||||||
|
}
|
||||||
|
guard let iterations = Int(iterationsString) else {
|
||||||
|
print("[ERROR] Invalid iterations argument")
|
||||||
|
exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
let server = Server(server: serverUrl, authentication: authenticationKey)
|
||||||
|
|
||||||
|
let classes: [Int]
|
||||||
|
|
||||||
|
do {
|
||||||
|
classes = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
|
||||||
|
.compactMap(Int.init)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to get model classes: \(error)")
|
||||||
|
exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
guard let oldVersion = server.getClassifierVersion() else {
|
||||||
|
print("[ERROR] Failed to get classifier version")
|
||||||
|
exit(1)
|
||||||
|
}
|
||||||
|
let newVersion = oldVersion + 1
|
||||||
|
|
||||||
|
print("[INFO] Image directory: \(imageDirectory.path)")
|
||||||
|
print("[INFO] Model path: \(classifierUrl.path)")
|
||||||
|
print("[INFO] Version: \(newVersion)")
|
||||||
|
print("[INFO] Classes: \(classes.count)")
|
||||||
|
print("[INFO] Iterations: \(iterations)")
|
||||||
|
|
||||||
var params = MLImageClassifier.ModelParameters(augmentation: [])
|
var params = MLImageClassifier.ModelParameters(augmentation: [])
|
||||||
params.maxIterations = arguments.iterations
|
params.maxIterations = iterations
|
||||||
|
|
||||||
let model = try MLImageClassifier(
|
let model: MLImageClassifier
|
||||||
trainingData: .labeledDirectories(at: arguments.images),
|
do {
|
||||||
parameters: params)
|
model = try MLImageClassifier(
|
||||||
|
trainingData: .labeledDirectories(at: imageDirectory),
|
||||||
|
parameters: params)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to create classifier: \(error)")
|
||||||
|
exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
print("[INFO] Writing classifier...")
|
print("[INFO] Saving classifier...")
|
||||||
try model.write(to: arguments.classifier)
|
do {
|
||||||
|
try model.write(to: classifierUrl)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to save model to file: \(error)")
|
||||||
|
exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
print("[INFO] Uploading classifier...")
|
||||||
let evaluation = model.evaluation(on: .labeledDirectories(at: trainDirectory))
|
let modelData = try Data(contentsOf: classifierUrl)
|
||||||
print("Printing evaluation:")
|
guard server.upload(classifier: modelData, version: newVersion) else {
|
||||||
print(evaluation)
|
print("[ERROR] Failed to upload classifier")
|
||||||
print("Finished evaluation")
|
exit(1)
|
||||||
*/
|
}
|
||||||
|
|
||||||
|
print("[INFO] Uploading trained classes...")
|
||||||
|
guard server.upload(classes: classes) else {
|
||||||
|
print("[ERROR] Failed to upload classes")
|
||||||
|
exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
print("[INFO] Done")
|
||||||
|
exit(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user