Add classifier update routes

This commit is contained in:
Christoph Hagen 2022-06-23 22:48:58 +02:00
parent a15b6754fb
commit 1f08a77eaf
2 changed files with 65 additions and 26 deletions

View File

@ -12,12 +12,34 @@ final class CapServer {
private let classifierVersionFile: URL
private let classifierFile: URL
private let fm = FileManager.default
// MARK: Caps
private var writers: Set<String>
var classifierVersion: Int {
get {
do {
let content = try String(contentsOf: classifierVersionFile)
return Int(content) ?? 0
} catch {
log("Failed to read classifier version file: \(error)")
return 0
}
}
set {
do {
try "\(newValue)".data(using: .utf8)!
.write(to: classifierVersionFile)
} catch {
log("Failed to save classifier version: \(error)")
}
}
}
/**
The time to wait for changes to be written to disk.
@ -44,6 +66,7 @@ final class CapServer {
self.imageFolder = folder.appendingPathComponent("images")
self.dbFile = folder.appendingPathComponent("caps.json")
self.classifierVersionFile = folder.appendingPathComponent("classifier.version")
self.classifierFile = folder.appendingPathComponent("classifier.mlmodel")
self.writers = Set(writers)
var isDirectory: ObjCBool = false
@ -216,25 +239,11 @@ final class CapServer {
log("Updated cap \(existingCap.id)")
}
func updateClassifierCaps(from url: URL) {
guard fm.fileExists(atPath: url.path) else {
return
}
let content: String
do {
content = try String(contentsOf: url)
} catch {
log("Failed to read classifier training result file: \(error)")
return
}
guard let version = readClassifierVersionFromDisk() else {
return
}
func updateTrainedClasses(content: String) {
let trainedCaps = content
.components(separatedBy: "\n")
.compactMap(Int.init)
let version = classifierVersion
for cap in trainedCaps {
if caps[cap]?.classifierVersion == nil {
caps[cap]?.classifierVersion = version
@ -242,13 +251,13 @@ final class CapServer {
}
}
private func readClassifierVersionFromDisk() -> Int? {
func save(classifier: Data, version: Int) throws {
do {
let content = try String(contentsOf: classifierVersionFile)
return Int(content)
try classifier.write(to: classifierFile)
} catch {
log("Failed to read classifier version file: \(error)")
return nil
log("Failed to write classifier: \(error)")
throw Abort(.internalServerError)
}
classifierVersion = version
}
}

View File

@ -14,6 +14,10 @@ private func authorize(_ request: Request) throws {
func routes(_ app: Application) throws {
app.getCatching("version") { _ in
return "\(server.classifierVersion)"
}
// Add or change a cap
app.postCatching("cap") { request in
try authorize(request)
@ -40,10 +44,36 @@ func routes(_ app: Application) throws {
try server.save(image: data, for: cap)
}
// Update the classifier versions from file on disk
app.getCatching("refresh") { _ in
let url = URL(fileURLWithPath: app.directory.resourcesDirectory)
.appendingPathComponent("classifier.trained")
server.updateClassifierCaps(from: url)
// Update the classifier
app.postCatching("classifier", ":version") { request in
try authorize(request)
guard let version = request.parameters.get("version", as: Int.self) else {
log("Invalid parameter for version")
throw Abort(.badRequest)
}
guard version > server.classifierVersion else {
throw Abort(.alreadyReported)
}
guard let buffer = request.body.data else {
log("Missing body data: \(request.body.description)")
throw CapError.invalidBody
}
let data = Data(buffer: buffer)
try server.save(image: data, for: version)
}
// Update the trained classes
app.postCatching("classes") { request in
try authorize(request)
guard let buffer = request.body.data else {
log("Missing body data: \(request.body.description)")
throw CapError.invalidBody
}
let data = Data(buffer: buffer)
guard let content = String(data: data, encoding: .utf8) else {
log("Invalid string body: \(request.body.description)")
throw CapError.invalidBody
}
server.updateTrainedClasses(content: content)
}
}