Switch to pure Swift training script
This commit is contained in:
parent
299dba0da4
commit
d516c1acd6
2
.gitignore
vendored
2
.gitignore
vendored
@ -13,5 +13,5 @@ Public/classifier.version
|
|||||||
Public/classifier.mlmodel
|
Public/classifier.mlmodel
|
||||||
Public/caps.json
|
Public/caps.json
|
||||||
Training/backup/
|
Training/backup/
|
||||||
Training/config.sh
|
Training/config.json
|
||||||
Public/thumbnails
|
Public/thumbnails
|
||||||
|
@ -22,6 +22,8 @@ final class CapServer: ServerOwner {
|
|||||||
|
|
||||||
private let fm = FileManager.default
|
private let fm = FileManager.default
|
||||||
|
|
||||||
|
private let changedImageEntryDateFormatter: DateFormatter
|
||||||
|
|
||||||
// MARK: Caps
|
// MARK: Caps
|
||||||
|
|
||||||
private var writers: Set<String>
|
private var writers: Set<String>
|
||||||
@ -77,6 +79,8 @@ final class CapServer: ServerOwner {
|
|||||||
self.classifierFile = folder.appendingPathComponent("classifier.mlmodel")
|
self.classifierFile = folder.appendingPathComponent("classifier.mlmodel")
|
||||||
self.changedImagesFile = folder.appendingPathComponent("changes.txt")
|
self.changedImagesFile = folder.appendingPathComponent("changes.txt")
|
||||||
self.writers = Set(writers)
|
self.writers = Set(writers)
|
||||||
|
self.changedImageEntryDateFormatter = DateFormatter()
|
||||||
|
changedImageEntryDateFormatter.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadData() throws {
|
func loadData() throws {
|
||||||
@ -303,9 +307,7 @@ final class CapServer: ServerOwner {
|
|||||||
unwrittenImageChanges = entries
|
unwrittenImageChanges = entries
|
||||||
try? handle.close()
|
try? handle.close()
|
||||||
}
|
}
|
||||||
let df = DateFormatter()
|
let dateString = changedImageEntryDateFormatter.string(from: Date())
|
||||||
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
|
||||||
let dateString = df.string(from: Date())
|
|
||||||
while let entry = entries.popLast() {
|
while let entry = entries.popLast() {
|
||||||
let content = "\(dateString):\(entry.cap):\(entry.image)\n".data(using: .utf8)!
|
let content = "\(dateString):\(entry.cap):\(entry.image)\n".data(using: .utf8)!
|
||||||
try handle.write(contentsOf: content)
|
try handle.write(contentsOf: content)
|
||||||
@ -333,11 +335,25 @@ final class CapServer: ServerOwner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func emptyChangedImageListFile() {
|
func removeAllEntriesInImageChangeList(before date: Date) {
|
||||||
do {
|
do {
|
||||||
try Data().write(to: changedImagesFile)
|
try String(contentsOf: changedImagesFile)
|
||||||
|
.components(separatedBy: "\n")
|
||||||
|
.filter { $0 != "" }
|
||||||
|
.compactMap { line -> String? in
|
||||||
|
guard let entryDate = changedImageEntryDateFormatter.date(from: line.components(separatedBy: ":").first!) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
guard entryDate > date else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
.joined(separator: "\n")
|
||||||
|
.data(using: .utf8)!
|
||||||
|
.write(to: changedImagesFile)
|
||||||
} catch {
|
} catch {
|
||||||
log("Failed to empty changed images file: \(error)")
|
log("Failed to update changed images file: \(error)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,6 +4,13 @@ import Foundation
|
|||||||
/// The decoder to extract caps from JSON payloads given to the `cap` route.
|
/// The decoder to extract caps from JSON payloads given to the `cap` route.
|
||||||
private let decoder = JSONDecoder()
|
private let decoder = JSONDecoder()
|
||||||
|
|
||||||
|
/// The date formatter to decode dates in requests
|
||||||
|
private let dateFormatter: DateFormatter = {
|
||||||
|
let df = DateFormatter()
|
||||||
|
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||||
|
return df
|
||||||
|
}()
|
||||||
|
|
||||||
private func authorize(_ request: Request) throws {
|
private func authorize(_ request: Request) throws {
|
||||||
guard let key = request.headers.first(name: "key") else {
|
guard let key = request.headers.first(name: "key") else {
|
||||||
throw Abort(.badRequest) // 400
|
throw Abort(.badRequest) // 400
|
||||||
@ -65,7 +72,16 @@ func routes(_ app: Application) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update the trained classes
|
// Update the trained classes
|
||||||
app.postCatching("classes") { request in
|
app.postCatching("classes", ":date") { request in
|
||||||
|
guard let dateString = request.parameters.get("date") else {
|
||||||
|
log("Invalid parameter for date")
|
||||||
|
throw Abort(.badRequest)
|
||||||
|
}
|
||||||
|
guard let date = dateFormatter.date(from: dateString) else {
|
||||||
|
log("Invalid date specification")
|
||||||
|
throw Abort(.badRequest)
|
||||||
|
}
|
||||||
|
|
||||||
try authorize(request)
|
try authorize(request)
|
||||||
guard let buffer = request.body.data else {
|
guard let buffer = request.body.data else {
|
||||||
log("Missing body data: \(request.body.description)")
|
log("Missing body data: \(request.body.description)")
|
||||||
@ -77,6 +93,6 @@ func routes(_ app: Application) {
|
|||||||
throw CapError.invalidBody
|
throw CapError.invalidBody
|
||||||
}
|
}
|
||||||
server.updateTrainedClasses(content: content)
|
server.updateTrainedClasses(content: content)
|
||||||
server.emptyChangedImageListFile()
|
server.removeAllEntriesInImageChangeList(before: date)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
7
Training/config_example.json
Normal file
7
Training/config_example.json
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"imageDirectory": "../Public/images",
|
||||||
|
"classifierModelPath": "../Public/classifier.mlmodel",
|
||||||
|
"trainingIterations": 20,
|
||||||
|
"serverPath": "https://mydomain.com/caps",
|
||||||
|
"authenticationToken": "mysecretkey",
|
||||||
|
}
|
@ -1,10 +0,0 @@
|
|||||||
IMAGE_DIR="../Public/images"
|
|
||||||
MODEL_FILE="../Public/classifier.mlmodel"
|
|
||||||
|
|
||||||
TRAINING_ITERATIONS="20"
|
|
||||||
SSH_PORT="22"
|
|
||||||
SERVER="pi@mydomain.com"
|
|
||||||
SERVER_ROOT_PATH="/caps/Public"
|
|
||||||
SERVER_PATH="https://mydomain.com/caps"
|
|
||||||
SERVER_AUTH="mysecretkey"
|
|
||||||
|
|
@ -1,48 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
########### SUDO COMMANDS WITHOUT PASSWORD #############################################
|
|
||||||
#
|
|
||||||
# Specific commands must be executable without sudo
|
|
||||||
#
|
|
||||||
# Add to sudoers file (sudo visudo):
|
|
||||||
#
|
|
||||||
# Disable password for specific commands
|
|
||||||
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/chmod -R 755 /data/public/capserver/images
|
|
||||||
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/mv /home/pi/classifier.mlmodel /data/public/capserver/
|
|
||||||
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/mv /home/pi/classifier.version /data/public/capserver/
|
|
||||||
# pi ALL=(ALL:ALL) NOPASSWD: /usr/bin/chown -R www-data\:www-data /data/public/capserver/
|
|
||||||
#
|
|
||||||
########################################################################################
|
|
||||||
|
|
||||||
########### EXPLANATION OF RSYNC FLAGS #################################################
|
|
||||||
#
|
|
||||||
# -h human readable output
|
|
||||||
# -v verbose output
|
|
||||||
# -r recursive
|
|
||||||
# -P print information about long-running transfers, keep partial files (—-progress --partial)
|
|
||||||
# -t preserves modification times
|
|
||||||
# -u only update if newer
|
|
||||||
#
|
|
||||||
########################################################################################
|
|
||||||
|
|
||||||
source config.sh
|
|
||||||
|
|
||||||
echo "[INFO] Ensuring permissions for images on server..."
|
|
||||||
ssh -p $SSH_PORT ${SERVER} "sudo chmod -R 755 ${SERVER_ROOT_PATH}/images"
|
|
||||||
|
|
||||||
retVal=$?
|
|
||||||
if [ $retVal -ne 0 ]; then
|
|
||||||
echo '[ERROR] Failed to change image permissions'
|
|
||||||
return $retVal
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "[INFO] Transferring images from server..."
|
|
||||||
rsync -hrut --info=progress2 -e "ssh -p ${SSH_PORT}" ${SERVER}:/${SERVER_ROOT_PATH}/images/ "${IMAGE_DIR}"
|
|
||||||
|
|
||||||
retVal=$?
|
|
||||||
if [ $retVal -ne 0 ]; then
|
|
||||||
echo '[ERROR] Failed to transfer images from server'
|
|
||||||
return $retVal
|
|
||||||
fi
|
|
||||||
|
|
||||||
swift train.swift $SERVER_PATH $SERVER_AUTH $IMAGE_DIR $MODEL_FILE $TRAINING_ITERATIONS
|
|
@ -2,126 +2,289 @@ import Foundation
|
|||||||
import Cocoa
|
import Cocoa
|
||||||
import CreateML
|
import CreateML
|
||||||
|
|
||||||
final class Server {
|
struct Configuration: Codable {
|
||||||
|
|
||||||
|
let imageDirectory: String
|
||||||
|
|
||||||
|
let classifierModelPath: String
|
||||||
|
|
||||||
|
let trainingIterations: Int
|
||||||
|
|
||||||
|
let serverPath: String
|
||||||
|
|
||||||
|
let authenticationToken: String
|
||||||
|
|
||||||
|
init?(at url: URL) {
|
||||||
|
do {
|
||||||
|
let configData = try Data(contentsOf: url)
|
||||||
|
self = try JSONDecoder().decode(Configuration.self, from: configData)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to load configuration at \(url.absoluteURL.path): \(error)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Cap: Codable {
|
||||||
|
|
||||||
|
let id: Int
|
||||||
|
|
||||||
|
let count: Int
|
||||||
|
|
||||||
|
enum CodingKeys: String, CodingKey {
|
||||||
|
case id = "i"
|
||||||
|
case count = "c"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
final class ClassifierCreator {
|
||||||
|
|
||||||
|
static let configurationFileUrl = URL(fileURLWithPath: "config.json")
|
||||||
|
|
||||||
let server: URL
|
let server: URL
|
||||||
|
|
||||||
let authentication: String
|
let configuration: Configuration
|
||||||
|
|
||||||
init(server: URL, authentication: String) {
|
let imageDirectory: URL
|
||||||
self.server = server
|
|
||||||
self.authentication = authentication
|
|
||||||
}
|
|
||||||
|
|
||||||
private func wait(for request: URLRequest) -> Data? {
|
let classifierUrl: URL
|
||||||
let group = DispatchGroup()
|
|
||||||
group.enter()
|
|
||||||
var result: Data? = nil
|
|
||||||
URLSession.shared.dataTask(with: request) { data, response, _ in
|
|
||||||
defer { group.leave() }
|
|
||||||
let code = (response as! HTTPURLResponse).statusCode
|
|
||||||
guard code == 200 else {
|
|
||||||
print("[ERROR] Invalid response \(code)")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
guard let data = data else {
|
|
||||||
print("[ERROR] No response data")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
result = data
|
|
||||||
}.resume()
|
|
||||||
group.wait()
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func getClassifierVersion() -> Int? {
|
let df = DateFormatter()
|
||||||
let group = DispatchGroup()
|
|
||||||
group.enter()
|
|
||||||
let classifierVersionUrl = server.appendingPathComponent("version")
|
// MARK: Step 1: Load configuration
|
||||||
guard let data = wait(for: URLRequest(url: classifierVersionUrl)) else {
|
|
||||||
|
init?() {
|
||||||
|
guard let configuration = Configuration(at: ClassifierCreator.configurationFileUrl) else {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
guard let string = String(data: data, encoding: .utf8) else {
|
self.configuration = configuration
|
||||||
print("[ERROR] Invalid classifier version \(data)")
|
guard let serverUrl = URL(string: configuration.serverPath) else {
|
||||||
|
print("[ERROR] Configuration: Invalid server path \(configuration.serverPath)")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
guard let int = Int(string) else {
|
self.server = serverUrl
|
||||||
|
self.imageDirectory = URL(fileURLWithPath: configuration.imageDirectory)
|
||||||
|
self.classifierUrl = URL(fileURLWithPath: configuration.classifierModelPath)
|
||||||
|
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Main function
|
||||||
|
|
||||||
|
func run() async {
|
||||||
|
let imagesSnapshotDate = Date()
|
||||||
|
guard let (classes, changedImageCount) = await loadImages() else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
guard !classes.isEmpty else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
guard changedImageCount > 0 else {
|
||||||
|
print("[INFO] No changed images, so no new classifier trained")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
guard let classifierVersion = await getClassifierVersion() else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
let newVersion = classifierVersion + 1
|
||||||
|
|
||||||
|
print("[INFO] Image directory: \(imageDirectory.absoluteURL.path)")
|
||||||
|
print("[INFO] Model path: \(configuration.classifierModelPath)")
|
||||||
|
print("[INFO] Version: \(newVersion)")
|
||||||
|
print("[INFO] Classes: \(classes.count)")
|
||||||
|
print("[INFO] Iterations: \(configuration.trainingIterations)")
|
||||||
|
|
||||||
|
guard trainModel() else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
guard await upload(version: newVersion) else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
guard await upload(classes: classes, lastUpdate: imagesSnapshotDate) else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
print("[INFO] Done")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Step 2: Load changed images
|
||||||
|
|
||||||
|
func loadImages() async -> (classes: [Int], changedImages: Int)? {
|
||||||
|
guard createImageFolderIfMissing() else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
let imageCounts = await getImageCounts()
|
||||||
|
let missingImageList: [(cap: Int, image: Int)] = imageCounts
|
||||||
|
.sorted { $0.key < $1.key }
|
||||||
|
.reduce(into: []) { list, pair in
|
||||||
|
let missingImagesForCap: [(cap: Int, image: Int)] = (0..<pair.value).compactMap { image in
|
||||||
|
let url = imageUrl(base: imageDirectory, cap: pair.key, image: image)
|
||||||
|
guard !FileManager.default.fileExists(atPath: url.path) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return (cap: pair.key, image: image)
|
||||||
|
}
|
||||||
|
list.append(contentsOf: missingImagesForCap)
|
||||||
|
}
|
||||||
|
if missingImageList.isEmpty {
|
||||||
|
print("[INFO] No missing images to load")
|
||||||
|
} else {
|
||||||
|
print("[INFO] Loading \(missingImageList.count) missing images...")
|
||||||
|
}
|
||||||
|
guard await loadImages(missingImageList) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
let changedImageList = await getChangedImageList()
|
||||||
|
|
||||||
|
if changedImageList.isEmpty {
|
||||||
|
print("[INFO] No changed images to load")
|
||||||
|
} else {
|
||||||
|
print("[INFO] Loading \(changedImageList.count) changed images...")
|
||||||
|
}
|
||||||
|
guard await loadImages(changedImageList) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
let classes = imageCounts.keys.sorted()
|
||||||
|
return (classes, missingImageList.count + changedImageList.count)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func createImageFolderIfMissing() -> Bool {
|
||||||
|
guard !FileManager.default.fileExists(atPath: imageDirectory.path) else {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
try FileManager.default.createDirectory(at: imageDirectory, withIntermediateDirectories: true)
|
||||||
|
return true
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to create image directory: \(error)")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func getImageCounts() async -> [Int : Int] {
|
||||||
|
guard let data: Data = await get(server.appendingPathComponent("caps.json")) else {
|
||||||
|
return [:]
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
return try JSONDecoder().decode([Cap].self, from: data)
|
||||||
|
.reduce(into: [:]) { $0[$1.id] = $1.count }
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to decode cap database: \(error)")
|
||||||
|
return [:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func imageUrl(base: URL, cap: Int, image: Int) -> URL {
|
||||||
|
base.appendingPathComponent(String(format: "%04d/%04d-%02d.jpg", cap, cap, image))
|
||||||
|
}
|
||||||
|
|
||||||
|
private func loadImage(cap: Int, image: Int) async -> Bool {
|
||||||
|
let url = imageUrl(base: server.appendingPathComponent("images"), cap: cap, image: image)
|
||||||
|
let tempFile: URL, response: URLResponse
|
||||||
|
do {
|
||||||
|
(tempFile, response) = try await URLSession.shared.download(from: url)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to load image \(image) of cap \(cap): \(error)")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
let responseCode = (response as! HTTPURLResponse).statusCode
|
||||||
|
guard responseCode == 200 else {
|
||||||
|
print("[ERROR] Failed to load image \(image) of cap \(cap): Response \(responseCode)")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
let localUrl = imageUrl(base: imageDirectory, cap: cap, image: image)
|
||||||
|
if FileManager.default.fileExists(atPath: localUrl.path) {
|
||||||
|
try FileManager.default.removeItem(at: localUrl)
|
||||||
|
}
|
||||||
|
try FileManager.default.moveItem(at: tempFile, to: localUrl)
|
||||||
|
return true
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to save image \(image) of cap \(cap): \(error)")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func loadImages(_ list: [(cap: Int, image: Int)]) async -> Bool {
|
||||||
|
guard !list.isEmpty else {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
var loadedImages = 0
|
||||||
|
await withTaskGroup(of: Bool.self) { group in
|
||||||
|
for (cap, image) in list {
|
||||||
|
group.addTask {
|
||||||
|
await self.loadImage(cap: cap, image: image)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for await loaded in group {
|
||||||
|
if loaded {
|
||||||
|
loadedImages += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if loadedImages != list.count {
|
||||||
|
print("[ERROR] Only \(loadedImages) of \(list.count) images loaded")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func getChangedImageList() async -> [(cap: Int, image: Int)] {
|
||||||
|
guard let string: String = await get(server.appendingPathComponent("changes.txt")) else {
|
||||||
|
print("[ERROR] Failed to get list of changed images")
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
return string
|
||||||
|
.components(separatedBy: "\n")
|
||||||
|
.filter { $0 != "" }
|
||||||
|
.compactMap {
|
||||||
|
let parts = $0.components(separatedBy: ":")
|
||||||
|
guard parts.count == 3 else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
guard let date = df.date(from: parts[0]) else {
|
||||||
|
print("[WARN] Invalid date \(parts[0]) in change list")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
guard let cap = Int(parts[1]) else {
|
||||||
|
print("[WARN] Invalid cap id \(parts[1]) in change list")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
guard let image = Int(parts[2]) else {
|
||||||
|
print("[WARN] Invalid image id \(parts[2]) in change list")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return (cap, image)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Step 3: Compute version
|
||||||
|
|
||||||
|
func getClassifierVersion() async -> Int? {
|
||||||
|
guard let string: String = await get(server.appendingPathComponent("version")) else {
|
||||||
|
print("[ERROR] Failed to get classifier version")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
guard let version = Int(string) else {
|
||||||
print("[ERROR] Invalid classifier version \(string)")
|
print("[ERROR] Invalid classifier version \(string)")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return int
|
return version
|
||||||
}
|
}
|
||||||
|
|
||||||
func upload(classifier: Data, version: Int) -> Bool {
|
// MARK: Step 4: Train classifier
|
||||||
let classifierUrl = server
|
|
||||||
.appendingPathComponent("classifier")
|
|
||||||
.appendingPathComponent("\(version)")
|
|
||||||
var request = URLRequest(url: classifierUrl)
|
|
||||||
request.httpMethod = "POST"
|
|
||||||
request.httpBody = classifier
|
|
||||||
request.addValue(authentication, forHTTPHeaderField: "key")
|
|
||||||
return wait(for: request) != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func upload(classes: [Int]) -> Bool {
|
|
||||||
let classifierUrl = server
|
|
||||||
.appendingPathComponent("classes")
|
|
||||||
var request = URLRequest(url: classifierUrl)
|
|
||||||
request.httpMethod = "POST"
|
|
||||||
request.httpBody = classes.map(String.init).joined(separator: "\n").data(using: .utf8)!
|
|
||||||
request.addValue(authentication, forHTTPHeaderField: "key")
|
|
||||||
return wait(for: request) != nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
let count = CommandLine.argc
|
|
||||||
guard count == 6 else {
|
|
||||||
print("[ERROR] Invalid number of arguments")
|
|
||||||
exit(1)
|
|
||||||
}
|
|
||||||
let serverPath = CommandLine.arguments[1]
|
|
||||||
let authenticationKey = CommandLine.arguments[2]
|
|
||||||
let imageDirectory = URL(fileURLWithPath: CommandLine.arguments[3])
|
|
||||||
let classifierUrl = URL(fileURLWithPath: CommandLine.arguments[4])
|
|
||||||
let iterationsString = CommandLine.arguments[5]
|
|
||||||
|
|
||||||
guard let serverUrl = URL(string: serverPath) else {
|
|
||||||
print("[ERROR] Invalid server path argument")
|
|
||||||
exit(1)
|
|
||||||
}
|
|
||||||
guard let iterations = Int(iterationsString) else {
|
|
||||||
print("[ERROR] Invalid iterations argument")
|
|
||||||
exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
let server = Server(server: serverUrl, authentication: authenticationKey)
|
|
||||||
|
|
||||||
let classes: [Int]
|
|
||||||
|
|
||||||
do {
|
|
||||||
classes = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
|
|
||||||
.compactMap(Int.init)
|
|
||||||
} catch {
|
|
||||||
print("[ERROR] Failed to get model classes: \(error)")
|
|
||||||
exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
guard let oldVersion = server.getClassifierVersion() else {
|
|
||||||
print("[ERROR] Failed to get classifier version")
|
|
||||||
exit(1)
|
|
||||||
}
|
|
||||||
let newVersion = oldVersion + 1
|
|
||||||
|
|
||||||
print("[INFO] Image directory: \(imageDirectory.path)")
|
|
||||||
print("[INFO] Model path: \(classifierUrl.path)")
|
|
||||||
print("[INFO] Version: \(newVersion)")
|
|
||||||
print("[INFO] Classes: \(classes.count)")
|
|
||||||
print("[INFO] Iterations: \(iterations)")
|
|
||||||
|
|
||||||
|
func trainModel() -> Bool {
|
||||||
var params = MLImageClassifier.ModelParameters(augmentation: [])
|
var params = MLImageClassifier.ModelParameters(augmentation: [])
|
||||||
params.maxIterations = iterations
|
params.maxIterations = configuration.trainingIterations
|
||||||
|
|
||||||
let model: MLImageClassifier
|
let model: MLImageClassifier
|
||||||
do {
|
do {
|
||||||
model = try MLImageClassifier(
|
model = try MLImageClassifier(
|
||||||
@ -129,29 +292,88 @@ do {
|
|||||||
parameters: params)
|
parameters: params)
|
||||||
} catch {
|
} catch {
|
||||||
print("[ERROR] Failed to create classifier: \(error)")
|
print("[ERROR] Failed to create classifier: \(error)")
|
||||||
exit(1)
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
print("[INFO] Saving classifier...")
|
print("[INFO] Saving classifier...")
|
||||||
do {
|
do {
|
||||||
try model.write(to: classifierUrl)
|
try model.write(to: classifierUrl)
|
||||||
|
return true
|
||||||
} catch {
|
} catch {
|
||||||
print("[ERROR] Failed to save model to file: \(error)")
|
print("[ERROR] Failed to save model to file: \(error)")
|
||||||
exit(1)
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: Step 5: Upload classifier
|
||||||
|
|
||||||
|
func upload(version: Int) async -> Bool {
|
||||||
print("[INFO] Uploading classifier...")
|
print("[INFO] Uploading classifier...")
|
||||||
let modelData = try Data(contentsOf: classifierUrl)
|
let modelData: Data
|
||||||
guard server.upload(classifier: modelData, version: newVersion) else {
|
do {
|
||||||
print("[ERROR] Failed to upload classifier")
|
modelData = try Data(contentsOf: classifierUrl)
|
||||||
exit(1)
|
} catch {
|
||||||
|
print("[ERROR] Failed to read classifier data: \(error)")
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return await post(
|
||||||
|
url: server.appendingPathComponent("classifier/\(version)"),
|
||||||
|
body: modelData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Step 6: Update classes
|
||||||
|
|
||||||
|
func upload(classes: [Int], lastUpdate: Date) async -> Bool {
|
||||||
print("[INFO] Uploading trained classes...")
|
print("[INFO] Uploading trained classes...")
|
||||||
guard server.upload(classes: classes) else {
|
let dateString = df.string(from: lastUpdate)
|
||||||
print("[ERROR] Failed to upload classes")
|
return await post(
|
||||||
exit(1)
|
url: server.appendingPathComponent("classes/\(dateString)"),
|
||||||
|
body: classes.map(String.init).joined(separator: "\n").data(using: .utf8)!,
|
||||||
|
dateKey: lastUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
print("[INFO] Done")
|
// MARK: Requests
|
||||||
exit(0)
|
|
||||||
|
private func post(url: URL, body: Data) async -> Bool {
|
||||||
|
var request = URLRequest(url: url)
|
||||||
|
request.httpMethod = "POST"
|
||||||
|
request.httpBody = body
|
||||||
|
request.addValue(configuration.authenticationToken, forHTTPHeaderField: "key")
|
||||||
|
return await perform(request) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
private func perform(_ request: URLRequest) async -> Data? {
|
||||||
|
let data: Data
|
||||||
|
let response: URLResponse
|
||||||
|
do {
|
||||||
|
(data, response) = try await URLSession.shared.data(for: request)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Request failed: \(error)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
let code = (response as! HTTPURLResponse).statusCode
|
||||||
|
guard code == 200 else {
|
||||||
|
print("[ERROR] Invalid response \(code)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
private func get(_ url: URL) async -> Data? {
|
||||||
|
await perform(URLRequest(url: url))
|
||||||
|
}
|
||||||
|
|
||||||
|
private func get(_ url: URL) async -> String? {
|
||||||
|
guard let data: Data = await get(url) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
guard let string = String(data: data, encoding: .utf8) else {
|
||||||
|
print("[ERROR] Invalid string response \(data)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await ClassifierCreator().run()
|
||||||
|
Loading…
Reference in New Issue
Block a user