Add classifier update routes
This commit is contained in:
parent
a15b6754fb
commit
1f08a77eaf
@ -12,12 +12,34 @@ final class CapServer {
|
||||
|
||||
private let classifierVersionFile: URL
|
||||
|
||||
private let classifierFile: URL
|
||||
|
||||
private let fm = FileManager.default
|
||||
|
||||
// MARK: Caps
|
||||
|
||||
private var writers: Set<String>
|
||||
|
||||
var classifierVersion: Int {
|
||||
get {
|
||||
do {
|
||||
let content = try String(contentsOf: classifierVersionFile)
|
||||
return Int(content) ?? 0
|
||||
} catch {
|
||||
log("Failed to read classifier version file: \(error)")
|
||||
return 0
|
||||
}
|
||||
}
|
||||
set {
|
||||
do {
|
||||
try "\(newValue)".data(using: .utf8)!
|
||||
.write(to: classifierVersionFile)
|
||||
} catch {
|
||||
log("Failed to save classifier version: \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
The time to wait for changes to be written to disk.
|
||||
|
||||
@ -44,6 +66,7 @@ final class CapServer {
|
||||
self.imageFolder = folder.appendingPathComponent("images")
|
||||
self.dbFile = folder.appendingPathComponent("caps.json")
|
||||
self.classifierVersionFile = folder.appendingPathComponent("classifier.version")
|
||||
self.classifierFile = folder.appendingPathComponent("classifier.mlmodel")
|
||||
self.writers = Set(writers)
|
||||
|
||||
var isDirectory: ObjCBool = false
|
||||
@ -216,25 +239,11 @@ final class CapServer {
|
||||
log("Updated cap \(existingCap.id)")
|
||||
}
|
||||
|
||||
func updateClassifierCaps(from url: URL) {
|
||||
guard fm.fileExists(atPath: url.path) else {
|
||||
return
|
||||
}
|
||||
let content: String
|
||||
do {
|
||||
content = try String(contentsOf: url)
|
||||
} catch {
|
||||
log("Failed to read classifier training result file: \(error)")
|
||||
return
|
||||
}
|
||||
guard let version = readClassifierVersionFromDisk() else {
|
||||
return
|
||||
}
|
||||
|
||||
func updateTrainedClasses(content: String) {
|
||||
let trainedCaps = content
|
||||
.components(separatedBy: "\n")
|
||||
.compactMap(Int.init)
|
||||
|
||||
let version = classifierVersion
|
||||
for cap in trainedCaps {
|
||||
if caps[cap]?.classifierVersion == nil {
|
||||
caps[cap]?.classifierVersion = version
|
||||
@ -242,13 +251,13 @@ final class CapServer {
|
||||
}
|
||||
}
|
||||
|
||||
private func readClassifierVersionFromDisk() -> Int? {
|
||||
func save(classifier: Data, version: Int) throws {
|
||||
do {
|
||||
let content = try String(contentsOf: classifierVersionFile)
|
||||
return Int(content)
|
||||
try classifier.write(to: classifierFile)
|
||||
} catch {
|
||||
log("Failed to read classifier version file: \(error)")
|
||||
return nil
|
||||
log("Failed to write classifier: \(error)")
|
||||
throw Abort(.internalServerError)
|
||||
}
|
||||
classifierVersion = version
|
||||
}
|
||||
}
|
||||
|
@ -14,6 +14,10 @@ private func authorize(_ request: Request) throws {
|
||||
|
||||
func routes(_ app: Application) throws {
|
||||
|
||||
app.getCatching("version") { _ in
|
||||
return "\(server.classifierVersion)"
|
||||
}
|
||||
|
||||
// Add or change a cap
|
||||
app.postCatching("cap") { request in
|
||||
try authorize(request)
|
||||
@ -40,10 +44,36 @@ func routes(_ app: Application) throws {
|
||||
try server.save(image: data, for: cap)
|
||||
}
|
||||
|
||||
// Update the classifier versions from file on disk
|
||||
app.getCatching("refresh") { _ in
|
||||
let url = URL(fileURLWithPath: app.directory.resourcesDirectory)
|
||||
.appendingPathComponent("classifier.trained")
|
||||
server.updateClassifierCaps(from: url)
|
||||
// Update the classifier
|
||||
app.postCatching("classifier", ":version") { request in
|
||||
try authorize(request)
|
||||
guard let version = request.parameters.get("version", as: Int.self) else {
|
||||
log("Invalid parameter for version")
|
||||
throw Abort(.badRequest)
|
||||
}
|
||||
guard version > server.classifierVersion else {
|
||||
throw Abort(.alreadyReported)
|
||||
}
|
||||
guard let buffer = request.body.data else {
|
||||
log("Missing body data: \(request.body.description)")
|
||||
throw CapError.invalidBody
|
||||
}
|
||||
let data = Data(buffer: buffer)
|
||||
try server.save(image: data, for: version)
|
||||
}
|
||||
|
||||
// Update the trained classes
|
||||
app.postCatching("classes") { request in
|
||||
try authorize(request)
|
||||
guard let buffer = request.body.data else {
|
||||
log("Missing body data: \(request.body.description)")
|
||||
throw CapError.invalidBody
|
||||
}
|
||||
let data = Data(buffer: buffer)
|
||||
guard let content = String(data: data, encoding: .utf8) else {
|
||||
log("Invalid string body: \(request.body.description)")
|
||||
throw CapError.invalidBody
|
||||
}
|
||||
server.updateTrainedClasses(content: content)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user