Add code from training script
This commit is contained in:
parent
fd484243be
commit
9a50328b93
6
Config/config_example.json
Normal file
6
Config/config_example.json
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"contentDirectory": "../Public",
|
||||||
|
"trainingIterations": 20,
|
||||||
|
"serverPath": "https://mydomain.com/caps",
|
||||||
|
"authenticationToken": "mysecretkey",
|
||||||
|
}
|
@ -1,16 +1,14 @@
|
|||||||
// swift-tools-version: 5.9
|
// swift-tools-version: 5.9
|
||||||
// The swift-tools-version declares the minimum version of Swift required to build this package.
|
|
||||||
|
|
||||||
import PackageDescription
|
import PackageDescription
|
||||||
|
|
||||||
let package = Package(
|
let package = Package(
|
||||||
name: "Cap-Train",
|
name: "Cap-Train",
|
||||||
|
platforms: [.macOS(.v12)],
|
||||||
dependencies: [
|
dependencies: [
|
||||||
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.0"),
|
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.0"),
|
||||||
],
|
],
|
||||||
targets: [
|
targets: [
|
||||||
// Targets are the basic building blocks of a package, defining a module or a test suite.
|
|
||||||
// Targets can depend on other targets in this package and products from dependencies.
|
|
||||||
.executableTarget(
|
.executableTarget(
|
||||||
name: "Cap-Train",
|
name: "Cap-Train",
|
||||||
dependencies: [
|
dependencies: [
|
||||||
|
27
Sources/Cap.swift
Normal file
27
Sources/Cap.swift
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import Foundation
|
||||||
|
|
||||||
|
struct Cap: Codable {
|
||||||
|
|
||||||
|
let id: Int
|
||||||
|
|
||||||
|
let count: Int
|
||||||
|
|
||||||
|
enum CodingKeys: String, CodingKey {
|
||||||
|
case id = "i"
|
||||||
|
case count = "c"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CapImage: Equatable {
|
||||||
|
|
||||||
|
let cap: Int
|
||||||
|
|
||||||
|
let image: Int
|
||||||
|
}
|
||||||
|
|
||||||
|
extension CapImage: CustomStringConvertible {
|
||||||
|
|
||||||
|
var description: String {
|
||||||
|
"image \(image) of cap \(cap)"
|
||||||
|
}
|
||||||
|
}
|
18
Sources/CapTrain.swift
Normal file
18
Sources/CapTrain.swift
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import ArgumentParser
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
@main
|
||||||
|
struct CapTrain: AsyncParsableCommand {
|
||||||
|
|
||||||
|
@Argument(help: "The path to the configuration file")
|
||||||
|
var configPath: String
|
||||||
|
|
||||||
|
func run() async throws {
|
||||||
|
let configurationFileUrl = URL(fileURLWithPath: configPath)
|
||||||
|
|
||||||
|
let configuration = try Configuration(at: configurationFileUrl)
|
||||||
|
let creator = try ClassifierCreator(configuration: configuration)
|
||||||
|
|
||||||
|
try await creator.run()
|
||||||
|
}
|
||||||
|
}
|
@ -1,14 +0,0 @@
|
|||||||
// The Swift Programming Language
|
|
||||||
// https://docs.swift.org/swift-book
|
|
||||||
//
|
|
||||||
// Swift Argument Parser
|
|
||||||
// https://swiftpackageindex.com/apple/swift-argument-parser/documentation
|
|
||||||
|
|
||||||
import ArgumentParser
|
|
||||||
|
|
||||||
@main
|
|
||||||
struct Cap_Train: ParsableCommand {
|
|
||||||
mutating func run() throws {
|
|
||||||
print("Hello, world!")
|
|
||||||
}
|
|
||||||
}
|
|
463
Sources/ClassifierCreator.swift
Normal file
463
Sources/ClassifierCreator.swift
Normal file
@ -0,0 +1,463 @@
|
|||||||
|
import Foundation
|
||||||
|
import CreateML
|
||||||
|
|
||||||
|
final class ClassifierCreator {
|
||||||
|
|
||||||
|
let server: URL
|
||||||
|
|
||||||
|
let configuration: Configuration
|
||||||
|
|
||||||
|
let imageDirectory: URL
|
||||||
|
|
||||||
|
let thumbnailDirectory: URL
|
||||||
|
|
||||||
|
let classifierUrl: URL
|
||||||
|
|
||||||
|
let df = DateFormatter()
|
||||||
|
|
||||||
|
|
||||||
|
// MARK: Step 1: Load configuration
|
||||||
|
|
||||||
|
init(configuration: Configuration) throws {
|
||||||
|
self.configuration = configuration
|
||||||
|
self.server = try configuration.serverUrl()
|
||||||
|
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
|
||||||
|
self.imageDirectory = contentDirectory.appendingPathComponent("images")
|
||||||
|
self.classifierUrl = contentDirectory.appendingPathComponent("classifier.mlmodel")
|
||||||
|
self.thumbnailDirectory = contentDirectory.appendingPathComponent("thumbnails")
|
||||||
|
df.dateFormat = "yy-MM-dd-HH-mm-ss"
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Main function
|
||||||
|
|
||||||
|
func run() async throws {
|
||||||
|
let imagesSnapshotDate = Date()
|
||||||
|
let (classes, changedImageCount, changedMainImages) = try await loadImages()
|
||||||
|
|
||||||
|
guard !classes.isEmpty else {
|
||||||
|
print("[INFO] No image classes found, exiting...")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
guard changedImageCount > 0 else {
|
||||||
|
print("[INFO] No changed images, so no new classifier trained")
|
||||||
|
await createThumbnails(changed: changedMainImages)
|
||||||
|
print("[INFO] Done")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
let classifierVersion = try await getClassifierVersion()
|
||||||
|
let newVersion = classifierVersion + 1
|
||||||
|
|
||||||
|
print("[INFO] Image directory: \(imageDirectory.absoluteURL.path)")
|
||||||
|
print("[INFO] Model path: \(classifierUrl.path)")
|
||||||
|
print("[INFO] Version: \(newVersion)")
|
||||||
|
print("[INFO] Classes: \(classes.count)")
|
||||||
|
print("[INFO] Iterations: \(configuration.trainingIterations)")
|
||||||
|
|
||||||
|
try trainModel()
|
||||||
|
try await upload(version: newVersion)
|
||||||
|
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
|
||||||
|
await createThumbnails(changed: changedMainImages)
|
||||||
|
print("[INFO] Done")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Step 2: Load changed images
|
||||||
|
|
||||||
|
func loadImages() async throws -> (classes: [Int], changedImageCount: Int, changedMainImages: [Int]) {
|
||||||
|
guard createFolderIfMissing(imageDirectory) else {
|
||||||
|
throw TrainingError.mainImageFolderNotCreated
|
||||||
|
}
|
||||||
|
let imageCounts = await getImageCounts()
|
||||||
|
let missingImageList: [CapImage] = imageCounts
|
||||||
|
.sorted { $0.key < $1.key }
|
||||||
|
.reduce(into: []) { list, pair in
|
||||||
|
let missingImagesForCap: [CapImage] = (0..<pair.value).compactMap { image in
|
||||||
|
let image = CapImage(cap: pair.key, image: image)
|
||||||
|
let url = imageUrl(base: imageDirectory, image: image)
|
||||||
|
guard !FileManager.default.fileExists(atPath: url.path) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return image
|
||||||
|
}
|
||||||
|
list.append(contentsOf: missingImagesForCap)
|
||||||
|
}
|
||||||
|
if missingImageList.isEmpty {
|
||||||
|
print("[INFO] No missing images to load")
|
||||||
|
} else {
|
||||||
|
print("[INFO] Loading \(missingImageList.count) missing images...")
|
||||||
|
try await loadImages(missingImageList)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
let changedImageList = try await getChangedImageList()
|
||||||
|
let filteredChangeList = changedImageList
|
||||||
|
.filter { $0.image < imageCounts[$0.cap] ?? 0 } // Filter non-existent images
|
||||||
|
.filter { !missingImageList.contains($0) }
|
||||||
|
|
||||||
|
let imagesAlreadyLoad = changedImageList.count - filteredChangeList.count
|
||||||
|
let suffix = imagesAlreadyLoad > 0 ? " (\(imagesAlreadyLoad) already loaded)" : ""
|
||||||
|
if filteredChangeList.isEmpty {
|
||||||
|
print("[INFO] No changed images to load" + suffix)
|
||||||
|
} else {
|
||||||
|
print("[INFO] Loading \(filteredChangeList.count) changed images" + suffix)
|
||||||
|
try await loadImages(filteredChangeList)
|
||||||
|
}
|
||||||
|
|
||||||
|
let changedMainImages = changedImageList.filter { $0.image == 0 }.map { $0.cap }
|
||||||
|
let classes = imageCounts.keys.sorted()
|
||||||
|
|
||||||
|
// Delete any image folders not present as caps
|
||||||
|
try deleteUnnecessaryImageFolders(caps: classes)
|
||||||
|
|
||||||
|
return (classes, missingImageList.count + changedImageList.count, changedMainImages)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func getImageCounts() async -> [Int : Int] {
|
||||||
|
guard let data: Data = await get(server.appendingPathComponent("caps.json")) else {
|
||||||
|
return [:]
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
return try JSONDecoder().decode([Cap].self, from: data)
|
||||||
|
.reduce(into: [:]) { $0[$1.id] = $1.count }
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to decode cap database: \(error)")
|
||||||
|
return [:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func deleteUnnecessaryImageFolders(caps: [Int]) throws {
|
||||||
|
let validNames = caps.map { String(format: "%04d", $0) }
|
||||||
|
let folders: [String]
|
||||||
|
do {
|
||||||
|
folders = try FileManager.default.contentsOfDirectory(atPath: imageDirectory.path)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to get list of image folders: \(error)")
|
||||||
|
throw TrainingError.failedToGetListOfImageFolders
|
||||||
|
}
|
||||||
|
for folder in folders {
|
||||||
|
if validNames.contains(folder) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not a valid cap folder
|
||||||
|
let url = imageDirectory.appendingPathComponent(folder)
|
||||||
|
do {
|
||||||
|
try FileManager.default.removeItem(at: url)
|
||||||
|
print("[INFO] Removed unused image folder '\(folder)'")
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to delete unused image folder \(folder): \(error)")
|
||||||
|
throw TrainingError.failedToRemoveImageFolder
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func imageUrl(base: URL, image: CapImage) -> URL {
|
||||||
|
base.appendingPathComponent(String(format: "%04d/%04d-%02d.jpg", image.cap, image.cap, image.image))
|
||||||
|
}
|
||||||
|
|
||||||
|
private func load(image: CapImage) async -> Bool {
|
||||||
|
guard createFolderIfMissing(imageDirectory.appendingPathComponent(String(format: "%04d", image.cap))) else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
let url = imageUrl(base: server.appendingPathComponent("images"), image: image)
|
||||||
|
let tempFile: URL, response: URLResponse
|
||||||
|
do {
|
||||||
|
(tempFile, response) = try await URLSession.shared.download(from: url)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to load \(image): \(error)")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
let responseCode = (response as! HTTPURLResponse).statusCode
|
||||||
|
guard responseCode == 200 else {
|
||||||
|
print("[ERROR] Failed to load \(image): Response \(responseCode)")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
let localUrl = imageUrl(base: imageDirectory, image: image)
|
||||||
|
if FileManager.default.fileExists(atPath: localUrl.path) {
|
||||||
|
try FileManager.default.removeItem(at: localUrl)
|
||||||
|
}
|
||||||
|
try FileManager.default.moveItem(at: tempFile, to: localUrl)
|
||||||
|
return true
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to save \(image): \(error)")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func loadImages(_ list: [CapImage]) async throws {
|
||||||
|
var loadedImages = 0
|
||||||
|
await withTaskGroup(of: Bool.self) { group in
|
||||||
|
for image in list {
|
||||||
|
group.addTask {
|
||||||
|
await self.load(image: image)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for await loaded in group {
|
||||||
|
if loaded {
|
||||||
|
loadedImages += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if loadedImages != list.count {
|
||||||
|
print("[ERROR] Only \(loadedImages) of \(list.count) images loaded")
|
||||||
|
throw TrainingError.failedToLoadImages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getChangedImageList() async throws -> [CapImage] {
|
||||||
|
guard let string: String = await get(server.appendingPathComponent("changes.txt")) else {
|
||||||
|
print("[ERROR] Failed to get list of changed images")
|
||||||
|
throw TrainingError.failedToGetChangedImagesList
|
||||||
|
}
|
||||||
|
|
||||||
|
return string
|
||||||
|
.components(separatedBy: "\n")
|
||||||
|
.filter { $0 != "" }
|
||||||
|
.compactMap {
|
||||||
|
let parts = $0.components(separatedBy: ":")
|
||||||
|
guard parts.count == 3 else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
guard let date = df.date(from: parts[0]) else {
|
||||||
|
print("[WARN] Invalid date \(parts[0]) in change list")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
guard let cap = Int(parts[1]) else {
|
||||||
|
print("[WARN] Invalid cap id \(parts[1]) in change list")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
guard let image = Int(parts[2]) else {
|
||||||
|
print("[WARN] Invalid image id \(parts[2]) in change list")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return CapImage(cap: cap, image: image)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Step 3: Compute version
|
||||||
|
|
||||||
|
func getClassifierVersion() async throws -> Int {
|
||||||
|
guard let string: String = await get(server.appendingPathComponent("version")) else {
|
||||||
|
print("[ERROR] Failed to get classifier version")
|
||||||
|
throw TrainingError.failedToGetClassifierVersion
|
||||||
|
}
|
||||||
|
guard let version = Int(string) else {
|
||||||
|
print("[ERROR] Invalid classifier version \(string)")
|
||||||
|
throw TrainingError.invalidClassifierVersion(string)
|
||||||
|
}
|
||||||
|
return version
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Step 4: Train classifier
|
||||||
|
|
||||||
|
func trainModel() throws {
|
||||||
|
var params = MLImageClassifier.ModelParameters(augmentation: [])
|
||||||
|
params.maxIterations = configuration.trainingIterations
|
||||||
|
let model: MLImageClassifier
|
||||||
|
do {
|
||||||
|
model = try MLImageClassifier(
|
||||||
|
trainingData: .labeledDirectories(at: imageDirectory),
|
||||||
|
parameters: params)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to create classifier: \(error)")
|
||||||
|
throw TrainingError.failedToCreateClassifier(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
print("[INFO] Saving classifier...")
|
||||||
|
do {
|
||||||
|
try model.write(to: classifierUrl)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to save model to file: \(error)")
|
||||||
|
throw TrainingError.failedToWriteClassifier(error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Step 5: Upload classifier
|
||||||
|
|
||||||
|
func upload(version: Int) async throws {
|
||||||
|
print("[INFO] Uploading classifier...")
|
||||||
|
let modelData: Data
|
||||||
|
do {
|
||||||
|
modelData = try Data(contentsOf: classifierUrl)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to read classifier data: \(error)")
|
||||||
|
throw TrainingError.failedToReadClassifierData(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
let success = await post(
|
||||||
|
url: server.appendingPathComponent("classifier/\(version)"),
|
||||||
|
body: modelData)
|
||||||
|
guard success else {
|
||||||
|
throw TrainingError.failedToUploadClassifier
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Step 6: Update classes
|
||||||
|
|
||||||
|
func upload(classes: [Int], lastUpdate: Date) async throws {
|
||||||
|
print("[INFO] Uploading trained classes...")
|
||||||
|
let dateString = df.string(from: lastUpdate)
|
||||||
|
let url = server.appendingPathComponent("classes/\(dateString)")
|
||||||
|
let body = classes.map(String.init).joined(separator: ",").data(using: .utf8)!
|
||||||
|
guard await post(url: url, body: body) else {
|
||||||
|
throw TrainingError.failedToUploadClassifierClasses
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Step 7: Create thumbnails
|
||||||
|
|
||||||
|
func createThumbnails(changed: [Int]) async {
|
||||||
|
guard checkMagickAvailability() else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
guard createFolderIfMissing(thumbnailDirectory) else {
|
||||||
|
print("[ERROR] Failed to create folder for thumbnails")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
let capIdsOfMissingThumbnails = await getMissingThumbnailIds()
|
||||||
|
let all = Set(capIdsOfMissingThumbnails).union(changed)
|
||||||
|
print("[INFO] Creating \(all.count) thumbnails...")
|
||||||
|
for cap in all {
|
||||||
|
await createThumbnail(for: cap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkMagickAvailability() -> Bool {
|
||||||
|
do {
|
||||||
|
let (code, output) = try safeShell("magick --version")
|
||||||
|
guard code == 0, let version = output.components(separatedBy: "ImageMagick ").dropFirst().first?
|
||||||
|
.components(separatedBy: " ").first else {
|
||||||
|
print("[ERROR] Magick not found, install using 'brew install imagemagick'")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
print("[INFO] Using magick \(version)")
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to get version of magick: (\(error))")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
private func getMissingThumbnailIds() async -> [Int] {
|
||||||
|
guard let string: String = await get(server.appendingPathComponent("thumbnails/missing")) else {
|
||||||
|
print("[ERROR] Failed to get missing thumbnails")
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
return string.components(separatedBy: ",").compactMap(Int.init)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func createThumbnail(for cap: Int) async {
|
||||||
|
let mainImage = CapImage(cap: cap, image: 0)
|
||||||
|
let inputUrl = imageUrl(base: imageDirectory, image: mainImage)
|
||||||
|
guard FileManager.default.fileExists(atPath: inputUrl.path) else {
|
||||||
|
print("[ERROR] Local main image not found for cap \(cap): \(inputUrl.path)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
let output = thumbnailDirectory.appendingPathComponent(String(format: "%04d.jpg", cap))
|
||||||
|
do {
|
||||||
|
let command = "magick convert \(inputUrl.path) -quality 70% -resize 100x100 \(output.path)"
|
||||||
|
let (code, output) = try safeShell(command)
|
||||||
|
if code != 0 {
|
||||||
|
print("Failed to create thumbnail for cap \(cap): \(output)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
print("Failed to read created thumbnail for cap \(cap): \(error)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
let data: Data
|
||||||
|
do {
|
||||||
|
data = try Data(contentsOf: output)
|
||||||
|
} catch {
|
||||||
|
print("Failed to read created thumbnail for cap \(cap): \(error)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
guard await post(url: server.appendingPathComponent("thumbnails/\(cap)"), body: data) else {
|
||||||
|
print("Failed to upload thumbnail for cap \(cap)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Helper
|
||||||
|
|
||||||
|
@discardableResult
|
||||||
|
private func safeShell(_ command: String) throws -> (code: Int32, output: String) {
|
||||||
|
let task = Process()
|
||||||
|
let pipe = Pipe()
|
||||||
|
|
||||||
|
task.standardOutput = pipe
|
||||||
|
task.standardError = pipe
|
||||||
|
task.arguments = ["-cl", command]
|
||||||
|
task.executableURL = URL(fileURLWithPath: "/bin/zsh")
|
||||||
|
task.standardInput = nil
|
||||||
|
|
||||||
|
try task.run()
|
||||||
|
task.waitUntilExit()
|
||||||
|
|
||||||
|
let data = pipe.fileHandleForReading.readDataToEndOfFile()
|
||||||
|
let output = String(data: data, encoding: .utf8)!
|
||||||
|
return (task.terminationStatus, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private func createFolderIfMissing(_ folder: URL) -> Bool {
|
||||||
|
guard !FileManager.default.fileExists(atPath: folder.path) else {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
try FileManager.default.createDirectory(at: folder, withIntermediateDirectories: true)
|
||||||
|
return true
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to create directory \(folder.path): \(error)")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: Requests
|
||||||
|
|
||||||
|
private func post(url: URL, body: Data) async -> Bool {
|
||||||
|
var request = URLRequest(url: url)
|
||||||
|
request.httpMethod = "POST"
|
||||||
|
request.httpBody = body
|
||||||
|
request.addValue(configuration.authenticationToken, forHTTPHeaderField: "key")
|
||||||
|
return await perform(request) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
private func perform(_ request: URLRequest) async -> Data? {
|
||||||
|
let data: Data
|
||||||
|
let response: URLResponse
|
||||||
|
do {
|
||||||
|
(data, response) = try await URLSession.shared.data(for: request)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Request to \(request.url!.absoluteString) failed: \(error)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
let code = (response as! HTTPURLResponse).statusCode
|
||||||
|
guard code == 200 else {
|
||||||
|
print("[ERROR] Request to \(request.url!.absoluteString): Invalid response \(code)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
private func get(_ url: URL) async -> Data? {
|
||||||
|
await perform(URLRequest(url: url))
|
||||||
|
}
|
||||||
|
|
||||||
|
private func get(_ url: URL) async -> String? {
|
||||||
|
guard let data: Data = await get(url) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
guard let string = String(data: data, encoding: .utf8) else {
|
||||||
|
print("[ERROR] Invalid string response \(data)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return string
|
||||||
|
}
|
||||||
|
}
|
40
Sources/Configuration.swift
Normal file
40
Sources/Configuration.swift
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import Foundation
|
||||||
|
|
||||||
|
struct Configuration: Codable {
|
||||||
|
|
||||||
|
let contentFolder: String
|
||||||
|
|
||||||
|
let trainingIterations: Int
|
||||||
|
|
||||||
|
let serverPath: String
|
||||||
|
|
||||||
|
let authenticationToken: String
|
||||||
|
|
||||||
|
init(at url: URL) throws {
|
||||||
|
guard FileManager.default.fileExists(atPath: url.path) else {
|
||||||
|
print("[ERROR] No configuration at \(url.absoluteURL.path)")
|
||||||
|
throw TrainingError.configurationFileMissing
|
||||||
|
}
|
||||||
|
let data: Data
|
||||||
|
do {
|
||||||
|
data = try Data(contentsOf: url)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to load configuration data at \(url.absoluteURL.path): \(error)")
|
||||||
|
throw TrainingError.configurationFileUnreadable
|
||||||
|
}
|
||||||
|
do {
|
||||||
|
self = try JSONDecoder().decode(Configuration.self, from: data)
|
||||||
|
} catch {
|
||||||
|
print("[ERROR] Failed to decode configuration at \(url.absoluteURL.path): \(error)")
|
||||||
|
throw TrainingError.configurationFileDecodingFailed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverUrl() throws -> URL {
|
||||||
|
guard let serverUrl = URL(string: serverPath) else {
|
||||||
|
print("[ERROR] Configuration: Invalid server path \(serverPath)")
|
||||||
|
throw TrainingError.invalidServerPath
|
||||||
|
}
|
||||||
|
return serverUrl
|
||||||
|
}
|
||||||
|
}
|
27
Sources/TrainingError.swift
Normal file
27
Sources/TrainingError.swift
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import Foundation
|
||||||
|
|
||||||
|
enum TrainingError: Error {
|
||||||
|
|
||||||
|
case configurationFileMissing
|
||||||
|
case configurationFileUnreadable
|
||||||
|
case configurationFileDecodingFailed
|
||||||
|
case invalidServerPath
|
||||||
|
case mainImageFolderNotCreated
|
||||||
|
case failedToLoadImages
|
||||||
|
|
||||||
|
case failedToGetListOfImageFolders
|
||||||
|
case failedToRemoveImageFolder
|
||||||
|
|
||||||
|
case failedToGetClassifierVersion
|
||||||
|
case invalidClassifierVersion(String)
|
||||||
|
|
||||||
|
case failedToCreateClassifier(Error)
|
||||||
|
case failedToWriteClassifier(Error)
|
||||||
|
|
||||||
|
case failedToGetChangedImagesList
|
||||||
|
|
||||||
|
case failedToReadClassifierData(Error)
|
||||||
|
case failedToUploadClassifier
|
||||||
|
|
||||||
|
case failedToUploadClassifierClasses
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user