Add classifier update routes

This commit is contained in:
Christoph Hagen 2022-06-23 22:48:58 +02:00
parent a15b6754fb
commit 1f08a77eaf
2 changed files with 65 additions and 26 deletions

View File

@ -12,11 +12,33 @@ final class CapServer {
private let classifierVersionFile: URL private let classifierVersionFile: URL
private let classifierFile: URL
private let fm = FileManager.default private let fm = FileManager.default
// MARK: Caps // MARK: Caps
private var writers: Set<String> private var writers: Set<String>
var classifierVersion: Int {
get {
do {
let content = try String(contentsOf: classifierVersionFile)
return Int(content) ?? 0
} catch {
log("Failed to read classifier version file: \(error)")
return 0
}
}
set {
do {
try "\(newValue)".data(using: .utf8)!
.write(to: classifierVersionFile)
} catch {
log("Failed to save classifier version: \(error)")
}
}
}
/** /**
The time to wait for changes to be written to disk. The time to wait for changes to be written to disk.
@ -44,6 +66,7 @@ final class CapServer {
self.imageFolder = folder.appendingPathComponent("images") self.imageFolder = folder.appendingPathComponent("images")
self.dbFile = folder.appendingPathComponent("caps.json") self.dbFile = folder.appendingPathComponent("caps.json")
self.classifierVersionFile = folder.appendingPathComponent("classifier.version") self.classifierVersionFile = folder.appendingPathComponent("classifier.version")
self.classifierFile = folder.appendingPathComponent("classifier.mlmodel")
self.writers = Set(writers) self.writers = Set(writers)
var isDirectory: ObjCBool = false var isDirectory: ObjCBool = false
@ -216,25 +239,11 @@ final class CapServer {
log("Updated cap \(existingCap.id)") log("Updated cap \(existingCap.id)")
} }
func updateClassifierCaps(from url: URL) { func updateTrainedClasses(content: String) {
guard fm.fileExists(atPath: url.path) else {
return
}
let content: String
do {
content = try String(contentsOf: url)
} catch {
log("Failed to read classifier training result file: \(error)")
return
}
guard let version = readClassifierVersionFromDisk() else {
return
}
let trainedCaps = content let trainedCaps = content
.components(separatedBy: "\n") .components(separatedBy: "\n")
.compactMap(Int.init) .compactMap(Int.init)
let version = classifierVersion
for cap in trainedCaps { for cap in trainedCaps {
if caps[cap]?.classifierVersion == nil { if caps[cap]?.classifierVersion == nil {
caps[cap]?.classifierVersion = version caps[cap]?.classifierVersion = version
@ -242,13 +251,13 @@ final class CapServer {
} }
} }
private func readClassifierVersionFromDisk() -> Int? { func save(classifier: Data, version: Int) throws {
do { do {
let content = try String(contentsOf: classifierVersionFile) try classifier.write(to: classifierFile)
return Int(content)
} catch { } catch {
log("Failed to read classifier version file: \(error)") log("Failed to write classifier: \(error)")
return nil throw Abort(.internalServerError)
} }
classifierVersion = version
} }
} }

View File

@ -13,6 +13,10 @@ private func authorize(_ request: Request) throws {
} }
func routes(_ app: Application) throws { func routes(_ app: Application) throws {
app.getCatching("version") { _ in
return "\(server.classifierVersion)"
}
// Add or change a cap // Add or change a cap
app.postCatching("cap") { request in app.postCatching("cap") { request in
@ -40,10 +44,36 @@ func routes(_ app: Application) throws {
try server.save(image: data, for: cap) try server.save(image: data, for: cap)
} }
// Update the classifier versions from file on disk // Update the classifier
app.getCatching("refresh") { _ in app.postCatching("classifier", ":version") { request in
let url = URL(fileURLWithPath: app.directory.resourcesDirectory) try authorize(request)
.appendingPathComponent("classifier.trained") guard let version = request.parameters.get("version", as: Int.self) else {
server.updateClassifierCaps(from: url) log("Invalid parameter for version")
throw Abort(.badRequest)
}
guard version > server.classifierVersion else {
throw Abort(.alreadyReported)
}
guard let buffer = request.body.data else {
log("Missing body data: \(request.body.description)")
throw CapError.invalidBody
}
let data = Data(buffer: buffer)
try server.save(image: data, for: version)
}
// Update the trained classes
app.postCatching("classes") { request in
try authorize(request)
guard let buffer = request.body.data else {
log("Missing body data: \(request.body.description)")
throw CapError.invalidBody
}
let data = Data(buffer: buffer)
guard let content = String(data: data, encoding: .utf8) else {
log("Invalid string body: \(request.body.description)")
throw CapError.invalidBody
}
server.updateTrainedClasses(content: content)
} }
} }