diff --git a/Sources/App/CapServer.swift b/Sources/App/CapServer.swift index 2c77fd2..98b8561 100644 --- a/Sources/App/CapServer.swift +++ b/Sources/App/CapServer.swift @@ -12,11 +12,33 @@ final class CapServer { private let classifierVersionFile: URL + private let classifierFile: URL + private let fm = FileManager.default // MARK: Caps private var writers: Set + + var classifierVersion: Int { + get { + do { + let content = try String(contentsOf: classifierVersionFile) + return Int(content) ?? 0 + } catch { + log("Failed to read classifier version file: \(error)") + return 0 + } + } + set { + do { + try "\(newValue)".data(using: .utf8)! + .write(to: classifierVersionFile) + } catch { + log("Failed to save classifier version: \(error)") + } + } + } /** The time to wait for changes to be written to disk. @@ -44,6 +66,7 @@ final class CapServer { self.imageFolder = folder.appendingPathComponent("images") self.dbFile = folder.appendingPathComponent("caps.json") self.classifierVersionFile = folder.appendingPathComponent("classifier.version") + self.classifierFile = folder.appendingPathComponent("classifier.mlmodel") self.writers = Set(writers) var isDirectory: ObjCBool = false @@ -216,25 +239,11 @@ final class CapServer { log("Updated cap \(existingCap.id)") } - func updateClassifierCaps(from url: URL) { - guard fm.fileExists(atPath: url.path) else { - return - } - let content: String - do { - content = try String(contentsOf: url) - } catch { - log("Failed to read classifier training result file: \(error)") - return - } - guard let version = readClassifierVersionFromDisk() else { - return - } - + func updateTrainedClasses(content: String) { let trainedCaps = content .components(separatedBy: "\n") .compactMap(Int.init) - + let version = classifierVersion for cap in trainedCaps { if caps[cap]?.classifierVersion == nil { caps[cap]?.classifierVersion = version @@ -242,13 +251,13 @@ final class CapServer { } } - private func readClassifierVersionFromDisk() -> Int? { + func save(classifier: Data, version: Int) throws { do { - let content = try String(contentsOf: classifierVersionFile) - return Int(content) + try classifier.write(to: classifierFile) } catch { - log("Failed to read classifier version file: \(error)") - return nil + log("Failed to write classifier: \(error)") + throw Abort(.internalServerError) } + classifierVersion = version } } diff --git a/Sources/App/routes.swift b/Sources/App/routes.swift index c12cc0c..d9fb331 100755 --- a/Sources/App/routes.swift +++ b/Sources/App/routes.swift @@ -13,6 +13,10 @@ private func authorize(_ request: Request) throws { } func routes(_ app: Application) throws { + + app.getCatching("version") { _ in + return "\(server.classifierVersion)" + } // Add or change a cap app.postCatching("cap") { request in @@ -40,10 +44,36 @@ func routes(_ app: Application) throws { try server.save(image: data, for: cap) } - // Update the classifier versions from file on disk - app.getCatching("refresh") { _ in - let url = URL(fileURLWithPath: app.directory.resourcesDirectory) - .appendingPathComponent("classifier.trained") - server.updateClassifierCaps(from: url) + // Update the classifier + app.postCatching("classifier", ":version") { request in + try authorize(request) + guard let version = request.parameters.get("version", as: Int.self) else { + log("Invalid parameter for version") + throw Abort(.badRequest) + } + guard version > server.classifierVersion else { + throw Abort(.alreadyReported) + } + guard let buffer = request.body.data else { + log("Missing body data: \(request.body.description)") + throw CapError.invalidBody + } + let data = Data(buffer: buffer) + try server.save(image: data, for: version) + } + + // Update the trained classes + app.postCatching("classes") { request in + try authorize(request) + guard let buffer = request.body.data else { + log("Missing body data: \(request.body.description)") + throw CapError.invalidBody + } + let data = Data(buffer: buffer) + guard let content = String(data: data, encoding: .utf8) else { + log("Invalid string body: \(request.body.description)") + throw CapError.invalidBody + } + server.updateTrainedClasses(content: content) } }