Add training script
This commit is contained in:
parent
40270d2131
commit
1289e034a7
80
Training/train.sh
Normal file
80
Training/train.sh
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
########### SUDO COMMANDS WITHOUT PASSWORD #############################################
|
||||||
|
#
|
||||||
|
# Specific commands must be executable without sudo
|
||||||
|
#
|
||||||
|
# Add to sudoers file (sudo visudo):
|
||||||
|
#
|
||||||
|
# Disable password for specific commands
|
||||||
|
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/chmod -R 755 /data/public/capserver/images
|
||||||
|
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/mv /home/pi/classifier.mlmodel /data/public/capserver/
|
||||||
|
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/mv /home/pi/classifier.version /data/public/capserver/
|
||||||
|
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/chown -R www-data\:www-data /data/public/capserver/
|
||||||
|
#
|
||||||
|
########################################################################################
|
||||||
|
|
||||||
|
########### EXPLANATION OF RSYNC FLAGS #################################################
|
||||||
|
#
|
||||||
|
# -h human readable output
|
||||||
|
# -v verbose output
|
||||||
|
# -r recursive
|
||||||
|
# -P print information about long-running transfers, keep partial files (—-progress --partial)
|
||||||
|
# -t preserves modification times
|
||||||
|
# -u only update if newer
|
||||||
|
#
|
||||||
|
########################################################################################
|
||||||
|
|
||||||
|
########### PATHS ######################################################################
|
||||||
|
|
||||||
|
WORK_DIR="/Users/imac/Projects/CapCollectorData"
|
||||||
|
BACKUP_DIR="${WORK_DIR}/Backup"
|
||||||
|
IMAGE_DIR="${WORK_DIR}/images"
|
||||||
|
VERSION_FILE="${WORK_DIR}/classifier.version"
|
||||||
|
MODEL_FILE="${WORK_DIR}/classifier.mlmodel"
|
||||||
|
|
||||||
|
TRAINING_ITERATIONS="17"
|
||||||
|
SSH_PORT="5432"
|
||||||
|
SERVER="pi@pi4"
|
||||||
|
SERVER_ROOT_PATH="/data/public/capserver"
|
||||||
|
|
||||||
|
########################################################################################
|
||||||
|
|
||||||
|
echo "[INFO] Working in directory ${WORK_DIR}"
|
||||||
|
|
||||||
|
echo "[INFO] Getting classifier version from server..."
|
||||||
|
scp -P $SSH_PORT ${SERVER}:/${SERVER_ROOT_PATH}/classifier.version $WORK_DIR
|
||||||
|
|
||||||
|
# Read classifier version from file
|
||||||
|
OLD_VERSION=$(< $VERSION_FILE)
|
||||||
|
NEW_VERSION=$(($OLD_VERSION + 1))
|
||||||
|
echo "[INFO] Creating classifier version ${NEW_VERSION}"
|
||||||
|
|
||||||
|
echo "[INFO] Ensuring permissions for images on server..."
|
||||||
|
ssh -p $SSH_PORT ${SERVER} "sudo chmod -R 755 ${SERVER_ROOT_PATH}/images"
|
||||||
|
|
||||||
|
echo "[INFO] Transferring images from server..."
|
||||||
|
rsync -hvrPut -e "ssh -p ${SSH_PORT}" ${SERVER}:/${SERVER_ROOT_PATH}/images/ "${IMAGE_DIR}"
|
||||||
|
|
||||||
|
echo "[INFO] Training the model..."
|
||||||
|
swift train.swift $IMAGE_DIR $TRAINING_ITERATIONS $MODEL_FILE
|
||||||
|
|
||||||
|
echo "[INFO] Backing up model..."
|
||||||
|
cp $MODEL_FILE "${BACKUP_DIR}/classifier${NEW_VERSION}.mlmodel"
|
||||||
|
|
||||||
|
echo "[INFO] Incrementing version file..."
|
||||||
|
echo "${NEW_VERSION}" > $VERSION_FILE
|
||||||
|
|
||||||
|
echo "[INFO] Copying the files to the server..."
|
||||||
|
scp -P $SSH_PORT $MODEL_FILE $VERSION_FILE ${SERVER}:~/
|
||||||
|
|
||||||
|
echo "[INFO] Moving files into public directory..."
|
||||||
|
ssh -p ${SSH_PORT} ${SERVER} "sudo mv /home/pi/classifier.* ${SERVER_ROOT_PATH}/"
|
||||||
|
|
||||||
|
echo "[INFO] Updating permissions..."
|
||||||
|
ssh -p ${SSH_PORT} ${SERVER} "sudo chown -R www-data\:www-data ${SERVER_ROOT_PATH}/"
|
||||||
|
|
||||||
|
echo "[INFO] Cleaning up..."
|
||||||
|
rm $MODEL_FILE $VERSION_FILE
|
||||||
|
|
||||||
|
echo "[INFO] Process finished"
|
61
Training/train.swift
Normal file
61
Training/train.swift
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import Cocoa
|
||||||
|
import CreateML
|
||||||
|
|
||||||
|
let defaultIterations = 17
|
||||||
|
let defaultWorkingDirectory = URL(fileURLWithPath: "/Users/imac/Development/CapCollectorData")
|
||||||
|
let defaultTrainDirectory = defaultWorkingDirectory.appendingPathComponent("images")
|
||||||
|
let defaultClassifierFile = defaultWorkingDirectory.appendingPathComponent("classifier.mlmodel")
|
||||||
|
|
||||||
|
func readArguments() -> (images: URL, classifier: URL, iterations: Int) {
|
||||||
|
let count = CommandLine.argc
|
||||||
|
guard count > 1 else {
|
||||||
|
// No arguments
|
||||||
|
return (defaultTrainDirectory, defaultClassifierFile, defaultIterations)
|
||||||
|
}
|
||||||
|
// First argument is the image directory
|
||||||
|
let imageDir = URL(fileURLWithPath: CommandLine.arguments[1])
|
||||||
|
let classifier = imageDir.deletingLastPathComponent().appendingPathComponent("classifier.mlmodel")
|
||||||
|
guard count > 2 else {
|
||||||
|
// Single argument is the image directory
|
||||||
|
return (imageDir, classifier, defaultIterations)
|
||||||
|
}
|
||||||
|
// Second argument is the iteration count
|
||||||
|
guard let iterations = Int(CommandLine.arguments[2]) else {
|
||||||
|
print("[ERROR] Invalid iterations argument '\(CommandLine.arguments[2])'")
|
||||||
|
exit(-1)
|
||||||
|
}
|
||||||
|
guard count > 3 else {
|
||||||
|
return (imageDir, classifier, iterations)
|
||||||
|
}
|
||||||
|
// Third argument is the classifier path
|
||||||
|
let classifierPath = URL(fileURLWithPath: CommandLine.arguments[3])
|
||||||
|
if count > 4 {
|
||||||
|
print("[WARNING] Ignoring additional arguments")
|
||||||
|
}
|
||||||
|
return (imageDir, classifierPath, iterations)
|
||||||
|
}
|
||||||
|
|
||||||
|
let arguments = readArguments()
|
||||||
|
|
||||||
|
print("[INFO] Using images in \(arguments.images.path)")
|
||||||
|
print("[INFO] Training for \(arguments.iterations) iterations")
|
||||||
|
print("[INFO] Classifier path set to \(arguments.classifier.path)")
|
||||||
|
|
||||||
|
var params = MLImageClassifier.ModelParameters(augmentation: [])
|
||||||
|
params.maxIterations = arguments.iterations
|
||||||
|
|
||||||
|
let model = try MLImageClassifier(
|
||||||
|
trainingData: .labeledDirectories(at: arguments.images),
|
||||||
|
parameters: params)
|
||||||
|
|
||||||
|
print("[INFO] Writing classifier...")
|
||||||
|
try model.write(to: arguments.classifier)
|
||||||
|
|
||||||
|
/*
|
||||||
|
let evaluation = model.evaluation(on: .labeledDirectories(at: trainDirectory))
|
||||||
|
print("Printing evaluation:")
|
||||||
|
print(evaluation)
|
||||||
|
print("Finished evaluation")
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user