Switch to pure Swift training script
This commit is contained in:
parent
299dba0da4
commit
d516c1acd6
2
.gitignore
vendored
2
.gitignore
vendored
@ -13,5 +13,5 @@ Public/classifier.version
|
||||
Public/classifier.mlmodel
|
||||
Public/caps.json
|
||||
Training/backup/
|
||||
Training/config.sh
|
||||
Training/config.json
|
||||
Public/thumbnails
|
||||
|
@ -22,6 +22,8 @@ final class CapServer: ServerOwner {
|
||||
|
||||
private let fm = FileManager.default
|
||||
|
||||
private let changedImageEntryDateFormatter: DateFormatter
|
||||
|
||||
// MARK: Caps
|
||||
|
||||
private var writers: Set<String>
|
||||
@ -77,6 +79,8 @@ final class CapServer: ServerOwner {
|
||||
self.classifierFile = folder.appendingPathComponent("classifier.mlmodel")
|
||||
self.changedImagesFile = folder.appendingPathComponent("changes.txt")
|
||||
self.writers = Set(writers)
|
||||
self.changedImageEntryDateFormatter = DateFormatter()
|
||||
changedImageEntryDateFormatter.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||
}
|
||||
|
||||
func loadData() throws {
|
||||
@ -303,9 +307,7 @@ final class CapServer: ServerOwner {
|
||||
unwrittenImageChanges = entries
|
||||
try? handle.close()
|
||||
}
|
||||
let df = DateFormatter()
|
||||
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||
let dateString = df.string(from: Date())
|
||||
let dateString = changedImageEntryDateFormatter.string(from: Date())
|
||||
while let entry = entries.popLast() {
|
||||
let content = "\(dateString):\(entry.cap):\(entry.image)\n".data(using: .utf8)!
|
||||
try handle.write(contentsOf: content)
|
||||
@ -333,11 +335,25 @@ final class CapServer: ServerOwner {
|
||||
}
|
||||
}
|
||||
|
||||
func emptyChangedImageListFile() {
|
||||
func removeAllEntriesInImageChangeList(before date: Date) {
|
||||
do {
|
||||
try Data().write(to: changedImagesFile)
|
||||
try String(contentsOf: changedImagesFile)
|
||||
.components(separatedBy: "\n")
|
||||
.filter { $0 != "" }
|
||||
.compactMap { line -> String? in
|
||||
guard let entryDate = changedImageEntryDateFormatter.date(from: line.components(separatedBy: ":").first!) else {
|
||||
return nil
|
||||
}
|
||||
guard entryDate > date else {
|
||||
return nil
|
||||
}
|
||||
return line
|
||||
}
|
||||
.joined(separator: "\n")
|
||||
.data(using: .utf8)!
|
||||
.write(to: changedImagesFile)
|
||||
} catch {
|
||||
log("Failed to empty changed images file: \(error)")
|
||||
log("Failed to update changed images file: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,13 @@ import Foundation
|
||||
/// The decoder to extract caps from JSON payloads given to the `cap` route.
|
||||
private let decoder = JSONDecoder()
|
||||
|
||||
/// The date formatter to decode dates in requests
|
||||
private let dateFormatter: DateFormatter = {
|
||||
let df = DateFormatter()
|
||||
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||
return df
|
||||
}()
|
||||
|
||||
private func authorize(_ request: Request) throws {
|
||||
guard let key = request.headers.first(name: "key") else {
|
||||
throw Abort(.badRequest) // 400
|
||||
@ -65,7 +72,16 @@ func routes(_ app: Application) {
|
||||
}
|
||||
|
||||
// Update the trained classes
|
||||
app.postCatching("classes") { request in
|
||||
app.postCatching("classes", ":date") { request in
|
||||
guard let dateString = request.parameters.get("date") else {
|
||||
log("Invalid parameter for date")
|
||||
throw Abort(.badRequest)
|
||||
}
|
||||
guard let date = dateFormatter.date(from: dateString) else {
|
||||
log("Invalid date specification")
|
||||
throw Abort(.badRequest)
|
||||
}
|
||||
|
||||
try authorize(request)
|
||||
guard let buffer = request.body.data else {
|
||||
log("Missing body data: \(request.body.description)")
|
||||
@ -77,6 +93,6 @@ func routes(_ app: Application) {
|
||||
throw CapError.invalidBody
|
||||
}
|
||||
server.updateTrainedClasses(content: content)
|
||||
server.emptyChangedImageListFile()
|
||||
server.removeAllEntriesInImageChangeList(before: date)
|
||||
}
|
||||
}
|
||||
|
7
Training/config_example.json
Normal file
7
Training/config_example.json
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"imageDirectory": "../Public/images",
|
||||
"classifierModelPath": "../Public/classifier.mlmodel",
|
||||
"trainingIterations": 20,
|
||||
"serverPath": "https://mydomain.com/caps",
|
||||
"authenticationToken": "mysecretkey",
|
||||
}
|
@ -1,10 +0,0 @@
|
||||
IMAGE_DIR="../Public/images"
|
||||
MODEL_FILE="../Public/classifier.mlmodel"
|
||||
|
||||
TRAINING_ITERATIONS="20"
|
||||
SSH_PORT="22"
|
||||
SERVER="pi@mydomain.com"
|
||||
SERVER_ROOT_PATH="/caps/Public"
|
||||
SERVER_PATH="https://mydomain.com/caps"
|
||||
SERVER_AUTH="mysecretkey"
|
||||
|
@ -1,48 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
########### SUDO COMMANDS WITHOUT PASSWORD #############################################
|
||||
#
|
||||
# Specific commands must be executable without sudo
|
||||
#
|
||||
# Add to sudoers file (sudo visudo):
|
||||
#
|
||||
# Disable password for specific commands
|
||||
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/chmod -R 755 /data/public/capserver/images
|
||||
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/mv /home/pi/classifier.mlmodel /data/public/capserver/
|
||||
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/mv /home/pi/classifier.version /data/public/capserver/
|
||||
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/chown -R www-data\:www-data /data/public/capserver/
|
||||
#
|
||||
########################################################################################
|
||||
|
||||
########### EXPLANATION OF RSYNC FLAGS #################################################
|
||||
#
|
||||
# -h human readable output
|
||||
# -v verbose output
|
||||
# -r recursive
|
||||
# -P print information about long-running transfers, keep partial files (—-progress --partial)
|
||||
# -t preserves modification times
|
||||
# -u only update if newer
|
||||
#
|
||||
########################################################################################
|
||||
|
||||
source config.sh
|
||||
|
||||
echo "[INFO] Ensuring permissions for images on server..."
|
||||
ssh -p $SSH_PORT ${SERVER} "sudo chmod -R 755 ${SERVER_ROOT_PATH}/images"
|
||||
|
||||
retVal=$?
|
||||
if [ $retVal -ne 0 ]; then
|
||||
echo '[ERROR] Failed to change image permissions'
|
||||
return $retVal
|
||||
fi
|
||||
|
||||
echo "[INFO] Transferring images from server..."
|
||||
rsync -hrut --info=progress2 -e "ssh -p ${SSH_PORT}" ${SERVER}:/${SERVER_ROOT_PATH}/images/ "${IMAGE_DIR}"
|
||||
|
||||
retVal=$?
|
||||
if [ $retVal -ne 0 ]; then
|
||||
echo '[ERROR] Failed to transfer images from server'
|
||||
return $retVal
|
||||
fi
|
||||
|
||||
swift train.swift $SERVER_PATH $SERVER_AUTH $IMAGE_DIR $MODEL_FILE $TRAINING_ITERATIONS
|
@ -2,156 +2,378 @@ import Foundation
|
||||
import Cocoa
|
||||
import CreateML
|
||||
|
||||
final class Server {
|
||||
struct Configuration: Codable {
|
||||
|
||||
let imageDirectory: String
|
||||
|
||||
let classifierModelPath: String
|
||||
|
||||
let trainingIterations: Int
|
||||
|
||||
let serverPath: String
|
||||
|
||||
let authenticationToken: String
|
||||
|
||||
init?(at url: URL) {
|
||||
do {
|
||||
let configData = try Data(contentsOf: url)
|
||||
self = try JSONDecoder().decode(Configuration.self, from: configData)
|
||||
} catch {
|
||||
print("[ERROR] Failed to load configuration at \(url.absoluteURL.path): \(error)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Cap: Codable {
|
||||
|
||||
let id: Int
|
||||
|
||||
let count: Int
|
||||
|
||||
enum CodingKeys: String, CodingKey {
|
||||
case id = "i"
|
||||
case count = "c"
|
||||
}
|
||||
}
|
||||
|
||||
final class ClassifierCreator {
|
||||
|
||||
static let configurationFileUrl = URL(fileURLWithPath: "config.json")
|
||||
|
||||
let server: URL
|
||||
|
||||
let authentication: String
|
||||
let configuration: Configuration
|
||||
|
||||
init(server: URL, authentication: String) {
|
||||
self.server = server
|
||||
self.authentication = authentication
|
||||
}
|
||||
let imageDirectory: URL
|
||||
|
||||
private func wait(for request: URLRequest) -> Data? {
|
||||
let group = DispatchGroup()
|
||||
group.enter()
|
||||
var result: Data? = nil
|
||||
URLSession.shared.dataTask(with: request) { data, response, _ in
|
||||
defer { group.leave() }
|
||||
let code = (response as! HTTPURLResponse).statusCode
|
||||
guard code == 200 else {
|
||||
print("[ERROR] Invalid response \(code)")
|
||||
return
|
||||
}
|
||||
guard let data = data else {
|
||||
print("[ERROR] No response data")
|
||||
return
|
||||
}
|
||||
result = data
|
||||
}.resume()
|
||||
group.wait()
|
||||
return result
|
||||
}
|
||||
let classifierUrl: URL
|
||||
|
||||
func getClassifierVersion() -> Int? {
|
||||
let group = DispatchGroup()
|
||||
group.enter()
|
||||
let classifierVersionUrl = server.appendingPathComponent("version")
|
||||
guard let data = wait(for: URLRequest(url: classifierVersionUrl)) else {
|
||||
let df = DateFormatter()
|
||||
|
||||
|
||||
// MARK: Step 1: Load configuration
|
||||
|
||||
init?() {
|
||||
guard let configuration = Configuration(at: ClassifierCreator.configurationFileUrl) else {
|
||||
return nil
|
||||
}
|
||||
guard let string = String(data: data, encoding: .utf8) else {
|
||||
print("[ERROR] Invalid classifier version \(data)")
|
||||
self.configuration = configuration
|
||||
guard let serverUrl = URL(string: configuration.serverPath) else {
|
||||
print("[ERROR] Configuration: Invalid server path \(configuration.serverPath)")
|
||||
return nil
|
||||
}
|
||||
guard let int = Int(string) else {
|
||||
self.server = serverUrl
|
||||
self.imageDirectory = URL(fileURLWithPath: configuration.imageDirectory)
|
||||
self.classifierUrl = URL(fileURLWithPath: configuration.classifierModelPath)
|
||||
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||
}
|
||||
|
||||
// MARK: Main function
|
||||
|
||||
func run() async {
|
||||
let imagesSnapshotDate = Date()
|
||||
guard let (classes, changedImageCount) = await loadImages() else {
|
||||
return
|
||||
}
|
||||
|
||||
guard !classes.isEmpty else {
|
||||
return
|
||||
}
|
||||
|
||||
guard changedImageCount > 0 else {
|
||||
print("[INFO] No changed images, so no new classifier trained")
|
||||
return
|
||||
}
|
||||
|
||||
guard let classifierVersion = await getClassifierVersion() else {
|
||||
return
|
||||
}
|
||||
let newVersion = classifierVersion + 1
|
||||
|
||||
print("[INFO] Image directory: \(imageDirectory.absoluteURL.path)")
|
||||
print("[INFO] Model path: \(configuration.classifierModelPath)")
|
||||
print("[INFO] Version: \(newVersion)")
|
||||
print("[INFO] Classes: \(classes.count)")
|
||||
print("[INFO] Iterations: \(configuration.trainingIterations)")
|
||||
|
||||
guard trainModel() else {
|
||||
return
|
||||
}
|
||||
|
||||
guard await upload(version: newVersion) else {
|
||||
return
|
||||
}
|
||||
|
||||
guard await upload(classes: classes, lastUpdate: imagesSnapshotDate) else {
|
||||
return
|
||||
}
|
||||
|
||||
print("[INFO] Done")
|
||||
}
|
||||
|
||||
// MARK: Step 2: Load changed images
|
||||
|
||||
func loadImages() async -> (classes: [Int], changedImages: Int)? {
|
||||
guard createImageFolderIfMissing() else {
|
||||
return nil
|
||||
}
|
||||
let imageCounts = await getImageCounts()
|
||||
let missingImageList: [(cap: Int, image: Int)] = imageCounts
|
||||
.sorted { $0.key < $1.key }
|
||||
.reduce(into: []) { list, pair in
|
||||
let missingImagesForCap: [(cap: Int, image: Int)] = (0..<pair.value).compactMap { image in
|
||||
let url = imageUrl(base: imageDirectory, cap: pair.key, image: image)
|
||||
guard !FileManager.default.fileExists(atPath: url.path) else {
|
||||
return nil
|
||||
}
|
||||
return (cap: pair.key, image: image)
|
||||
}
|
||||
list.append(contentsOf: missingImagesForCap)
|
||||
}
|
||||
if missingImageList.isEmpty {
|
||||
print("[INFO] No missing images to load")
|
||||
} else {
|
||||
print("[INFO] Loading \(missingImageList.count) missing images...")
|
||||
}
|
||||
guard await loadImages(missingImageList) else {
|
||||
return nil
|
||||
}
|
||||
let changedImageList = await getChangedImageList()
|
||||
|
||||
if changedImageList.isEmpty {
|
||||
print("[INFO] No changed images to load")
|
||||
} else {
|
||||
print("[INFO] Loading \(changedImageList.count) changed images...")
|
||||
}
|
||||
guard await loadImages(changedImageList) else {
|
||||
return nil
|
||||
}
|
||||
let classes = imageCounts.keys.sorted()
|
||||
return (classes, missingImageList.count + changedImageList.count)
|
||||
}
|
||||
|
||||
private func createImageFolderIfMissing() -> Bool {
|
||||
guard !FileManager.default.fileExists(atPath: imageDirectory.path) else {
|
||||
return true
|
||||
}
|
||||
do {
|
||||
try FileManager.default.createDirectory(at: imageDirectory, withIntermediateDirectories: true)
|
||||
return true
|
||||
} catch {
|
||||
print("[ERROR] Failed to create image directory: \(error)")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private func getImageCounts() async -> [Int : Int] {
|
||||
guard let data: Data = await get(server.appendingPathComponent("caps.json")) else {
|
||||
return [:]
|
||||
}
|
||||
do {
|
||||
return try JSONDecoder().decode([Cap].self, from: data)
|
||||
.reduce(into: [:]) { $0[$1.id] = $1.count }
|
||||
} catch {
|
||||
print("[ERROR] Failed to decode cap database: \(error)")
|
||||
return [:]
|
||||
}
|
||||
}
|
||||
|
||||
private func imageUrl(base: URL, cap: Int, image: Int) -> URL {
|
||||
base.appendingPathComponent(String(format: "%04d/%04d-%02d.jpg", cap, cap, image))
|
||||
}
|
||||
|
||||
private func loadImage(cap: Int, image: Int) async -> Bool {
|
||||
let url = imageUrl(base: server.appendingPathComponent("images"), cap: cap, image: image)
|
||||
let tempFile: URL, response: URLResponse
|
||||
do {
|
||||
(tempFile, response) = try await URLSession.shared.download(from: url)
|
||||
} catch {
|
||||
print("[ERROR] Failed to load image \(image) of cap \(cap): \(error)")
|
||||
return false
|
||||
}
|
||||
let responseCode = (response as! HTTPURLResponse).statusCode
|
||||
guard responseCode == 200 else {
|
||||
print("[ERROR] Failed to load image \(image) of cap \(cap): Response \(responseCode)")
|
||||
return false
|
||||
}
|
||||
do {
|
||||
let localUrl = imageUrl(base: imageDirectory, cap: cap, image: image)
|
||||
if FileManager.default.fileExists(atPath: localUrl.path) {
|
||||
try FileManager.default.removeItem(at: localUrl)
|
||||
}
|
||||
try FileManager.default.moveItem(at: tempFile, to: localUrl)
|
||||
return true
|
||||
} catch {
|
||||
print("[ERROR] Failed to save image \(image) of cap \(cap): \(error)")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private func loadImages(_ list: [(cap: Int, image: Int)]) async -> Bool {
|
||||
guard !list.isEmpty else {
|
||||
return true
|
||||
}
|
||||
var loadedImages = 0
|
||||
await withTaskGroup(of: Bool.self) { group in
|
||||
for (cap, image) in list {
|
||||
group.addTask {
|
||||
await self.loadImage(cap: cap, image: image)
|
||||
}
|
||||
}
|
||||
for await loaded in group {
|
||||
if loaded {
|
||||
loadedImages += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
if loadedImages != list.count {
|
||||
print("[ERROR] Only \(loadedImages) of \(list.count) images loaded")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func getChangedImageList() async -> [(cap: Int, image: Int)] {
|
||||
guard let string: String = await get(server.appendingPathComponent("changes.txt")) else {
|
||||
print("[ERROR] Failed to get list of changed images")
|
||||
return []
|
||||
}
|
||||
|
||||
return string
|
||||
.components(separatedBy: "\n")
|
||||
.filter { $0 != "" }
|
||||
.compactMap {
|
||||
let parts = $0.components(separatedBy: ":")
|
||||
guard parts.count == 3 else {
|
||||
return nil
|
||||
}
|
||||
/*
|
||||
guard let date = df.date(from: parts[0]) else {
|
||||
print("[WARN] Invalid date \(parts[0]) in change list")
|
||||
return nil
|
||||
}
|
||||
*/
|
||||
guard let cap = Int(parts[1]) else {
|
||||
print("[WARN] Invalid cap id \(parts[1]) in change list")
|
||||
return nil
|
||||
}
|
||||
guard let image = Int(parts[2]) else {
|
||||
print("[WARN] Invalid image id \(parts[2]) in change list")
|
||||
return nil
|
||||
}
|
||||
return (cap, image)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: Step 3: Compute version
|
||||
|
||||
func getClassifierVersion() async -> Int? {
|
||||
guard let string: String = await get(server.appendingPathComponent("version")) else {
|
||||
print("[ERROR] Failed to get classifier version")
|
||||
return nil
|
||||
}
|
||||
guard let version = Int(string) else {
|
||||
print("[ERROR] Invalid classifier version \(string)")
|
||||
return nil
|
||||
}
|
||||
return int
|
||||
return version
|
||||
}
|
||||
|
||||
func upload(classifier: Data, version: Int) -> Bool {
|
||||
let classifierUrl = server
|
||||
.appendingPathComponent("classifier")
|
||||
.appendingPathComponent("\(version)")
|
||||
var request = URLRequest(url: classifierUrl)
|
||||
// MARK: Step 4: Train classifier
|
||||
|
||||
func trainModel() -> Bool {
|
||||
var params = MLImageClassifier.ModelParameters(augmentation: [])
|
||||
params.maxIterations = configuration.trainingIterations
|
||||
let model: MLImageClassifier
|
||||
do {
|
||||
model = try MLImageClassifier(
|
||||
trainingData: .labeledDirectories(at: imageDirectory),
|
||||
parameters: params)
|
||||
} catch {
|
||||
print("[ERROR] Failed to create classifier: \(error)")
|
||||
return false
|
||||
}
|
||||
|
||||
print("[INFO] Saving classifier...")
|
||||
do {
|
||||
try model.write(to: classifierUrl)
|
||||
return true
|
||||
} catch {
|
||||
print("[ERROR] Failed to save model to file: \(error)")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: Step 5: Upload classifier
|
||||
|
||||
func upload(version: Int) async -> Bool {
|
||||
print("[INFO] Uploading classifier...")
|
||||
let modelData: Data
|
||||
do {
|
||||
modelData = try Data(contentsOf: classifierUrl)
|
||||
} catch {
|
||||
print("[ERROR] Failed to read classifier data: \(error)")
|
||||
return false
|
||||
}
|
||||
|
||||
return await post(
|
||||
url: server.appendingPathComponent("classifier/\(version)"),
|
||||
body: modelData)
|
||||
}
|
||||
|
||||
// MARK: Step 6: Update classes
|
||||
|
||||
func upload(classes: [Int], lastUpdate: Date) async -> Bool {
|
||||
print("[INFO] Uploading trained classes...")
|
||||
let dateString = df.string(from: lastUpdate)
|
||||
return await post(
|
||||
url: server.appendingPathComponent("classes/\(dateString)"),
|
||||
body: classes.map(String.init).joined(separator: "\n").data(using: .utf8)!,
|
||||
dateKey: lastUpdate)
|
||||
}
|
||||
|
||||
// MARK: Requests
|
||||
|
||||
private func post(url: URL, body: Data) async -> Bool {
|
||||
var request = URLRequest(url: url)
|
||||
request.httpMethod = "POST"
|
||||
request.httpBody = classifier
|
||||
request.addValue(authentication, forHTTPHeaderField: "key")
|
||||
return wait(for: request) != nil
|
||||
request.httpBody = body
|
||||
request.addValue(configuration.authenticationToken, forHTTPHeaderField: "key")
|
||||
return await perform(request) != nil
|
||||
}
|
||||
|
||||
func upload(classes: [Int]) -> Bool {
|
||||
let classifierUrl = server
|
||||
.appendingPathComponent("classes")
|
||||
var request = URLRequest(url: classifierUrl)
|
||||
request.httpMethod = "POST"
|
||||
request.httpBody = classes.map(String.init).joined(separator: "\n").data(using: .utf8)!
|
||||
request.addValue(authentication, forHTTPHeaderField: "key")
|
||||
return wait(for: request) != nil
|
||||
private func perform(_ request: URLRequest) async -> Data? {
|
||||
let data: Data
|
||||
let response: URLResponse
|
||||
do {
|
||||
(data, response) = try await URLSession.shared.data(for: request)
|
||||
} catch {
|
||||
print("[ERROR] Request failed: \(error)")
|
||||
return nil
|
||||
}
|
||||
let code = (response as! HTTPURLResponse).statusCode
|
||||
guard code == 200 else {
|
||||
print("[ERROR] Invalid response \(code)")
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
private func get(_ url: URL) async -> Data? {
|
||||
await perform(URLRequest(url: url))
|
||||
}
|
||||
|
||||
private func get(_ url: URL) async -> String? {
|
||||
guard let data: Data = await get(url) else {
|
||||
return nil
|
||||
}
|
||||
guard let string = String(data: data, encoding: .utf8) else {
|
||||
print("[ERROR] Invalid string response \(data)")
|
||||
return nil
|
||||
}
|
||||
return string
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let count = CommandLine.argc
|
||||
guard count == 6 else {
|
||||
print("[ERROR] Invalid number of arguments")
|
||||
exit(1)
|
||||
}
|
||||
let serverPath = CommandLine.arguments[1]
|
||||
let authenticationKey = CommandLine.arguments[2]
|
||||
let imageDirectory = URL(fileURLWithPath: CommandLine.arguments[3])
|
||||
let classifierUrl = URL(fileURLWithPath: CommandLine.arguments[4])
|
||||
let iterationsString = CommandLine.arguments[5]
|
||||
|
||||
guard let serverUrl = URL(string: serverPath) else {
|
||||
print("[ERROR] Invalid server path argument")
|
||||
exit(1)
|
||||
}
|
||||
guard let iterations = Int(iterationsString) else {
|
||||
print("[ERROR] Invalid iterations argument")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
let server = Server(server: serverUrl, authentication: authenticationKey)
|
||||
|
||||
let classes: [Int]
|
||||
|
||||
do {
|
||||
classes = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
|
||||
.compactMap(Int.init)
|
||||
} catch {
|
||||
print("[ERROR] Failed to get model classes: \(error)")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
guard let oldVersion = server.getClassifierVersion() else {
|
||||
print("[ERROR] Failed to get classifier version")
|
||||
exit(1)
|
||||
}
|
||||
let newVersion = oldVersion + 1
|
||||
|
||||
print("[INFO] Image directory: \(imageDirectory.path)")
|
||||
print("[INFO] Model path: \(classifierUrl.path)")
|
||||
print("[INFO] Version: \(newVersion)")
|
||||
print("[INFO] Classes: \(classes.count)")
|
||||
print("[INFO] Iterations: \(iterations)")
|
||||
|
||||
var params = MLImageClassifier.ModelParameters(augmentation: [])
|
||||
params.maxIterations = iterations
|
||||
|
||||
let model: MLImageClassifier
|
||||
do {
|
||||
model = try MLImageClassifier(
|
||||
trainingData: .labeledDirectories(at: imageDirectory),
|
||||
parameters: params)
|
||||
} catch {
|
||||
print("[ERROR] Failed to create classifier: \(error)")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
print("[INFO] Saving classifier...")
|
||||
do {
|
||||
try model.write(to: classifierUrl)
|
||||
} catch {
|
||||
print("[ERROR] Failed to save model to file: \(error)")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
print("[INFO] Uploading classifier...")
|
||||
let modelData = try Data(contentsOf: classifierUrl)
|
||||
guard server.upload(classifier: modelData, version: newVersion) else {
|
||||
print("[ERROR] Failed to upload classifier")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
print("[INFO] Uploading trained classes...")
|
||||
guard server.upload(classes: classes) else {
|
||||
print("[ERROR] Failed to upload classes")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
print("[INFO] Done")
|
||||
exit(0)
|
||||
await ClassifierCreator().run()
|
||||
|
Loading…
Reference in New Issue
Block a user