Add classifier update routes
This commit is contained in:
parent
a15b6754fb
commit
1f08a77eaf
@ -12,11 +12,33 @@ final class CapServer {
|
|||||||
|
|
||||||
private let classifierVersionFile: URL
|
private let classifierVersionFile: URL
|
||||||
|
|
||||||
|
private let classifierFile: URL
|
||||||
|
|
||||||
private let fm = FileManager.default
|
private let fm = FileManager.default
|
||||||
|
|
||||||
// MARK: Caps
|
// MARK: Caps
|
||||||
|
|
||||||
private var writers: Set<String>
|
private var writers: Set<String>
|
||||||
|
|
||||||
|
var classifierVersion: Int {
|
||||||
|
get {
|
||||||
|
do {
|
||||||
|
let content = try String(contentsOf: classifierVersionFile)
|
||||||
|
return Int(content) ?? 0
|
||||||
|
} catch {
|
||||||
|
log("Failed to read classifier version file: \(error)")
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
set {
|
||||||
|
do {
|
||||||
|
try "\(newValue)".data(using: .utf8)!
|
||||||
|
.write(to: classifierVersionFile)
|
||||||
|
} catch {
|
||||||
|
log("Failed to save classifier version: \(error)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
The time to wait for changes to be written to disk.
|
The time to wait for changes to be written to disk.
|
||||||
@ -44,6 +66,7 @@ final class CapServer {
|
|||||||
self.imageFolder = folder.appendingPathComponent("images")
|
self.imageFolder = folder.appendingPathComponent("images")
|
||||||
self.dbFile = folder.appendingPathComponent("caps.json")
|
self.dbFile = folder.appendingPathComponent("caps.json")
|
||||||
self.classifierVersionFile = folder.appendingPathComponent("classifier.version")
|
self.classifierVersionFile = folder.appendingPathComponent("classifier.version")
|
||||||
|
self.classifierFile = folder.appendingPathComponent("classifier.mlmodel")
|
||||||
self.writers = Set(writers)
|
self.writers = Set(writers)
|
||||||
|
|
||||||
var isDirectory: ObjCBool = false
|
var isDirectory: ObjCBool = false
|
||||||
@ -216,25 +239,11 @@ final class CapServer {
|
|||||||
log("Updated cap \(existingCap.id)")
|
log("Updated cap \(existingCap.id)")
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateClassifierCaps(from url: URL) {
|
func updateTrainedClasses(content: String) {
|
||||||
guard fm.fileExists(atPath: url.path) else {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
let content: String
|
|
||||||
do {
|
|
||||||
content = try String(contentsOf: url)
|
|
||||||
} catch {
|
|
||||||
log("Failed to read classifier training result file: \(error)")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
guard let version = readClassifierVersionFromDisk() else {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
let trainedCaps = content
|
let trainedCaps = content
|
||||||
.components(separatedBy: "\n")
|
.components(separatedBy: "\n")
|
||||||
.compactMap(Int.init)
|
.compactMap(Int.init)
|
||||||
|
let version = classifierVersion
|
||||||
for cap in trainedCaps {
|
for cap in trainedCaps {
|
||||||
if caps[cap]?.classifierVersion == nil {
|
if caps[cap]?.classifierVersion == nil {
|
||||||
caps[cap]?.classifierVersion = version
|
caps[cap]?.classifierVersion = version
|
||||||
@ -242,13 +251,13 @@ final class CapServer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private func readClassifierVersionFromDisk() -> Int? {
|
func save(classifier: Data, version: Int) throws {
|
||||||
do {
|
do {
|
||||||
let content = try String(contentsOf: classifierVersionFile)
|
try classifier.write(to: classifierFile)
|
||||||
return Int(content)
|
|
||||||
} catch {
|
} catch {
|
||||||
log("Failed to read classifier version file: \(error)")
|
log("Failed to write classifier: \(error)")
|
||||||
return nil
|
throw Abort(.internalServerError)
|
||||||
}
|
}
|
||||||
|
classifierVersion = version
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,10 @@ private func authorize(_ request: Request) throws {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func routes(_ app: Application) throws {
|
func routes(_ app: Application) throws {
|
||||||
|
|
||||||
|
app.getCatching("version") { _ in
|
||||||
|
return "\(server.classifierVersion)"
|
||||||
|
}
|
||||||
|
|
||||||
// Add or change a cap
|
// Add or change a cap
|
||||||
app.postCatching("cap") { request in
|
app.postCatching("cap") { request in
|
||||||
@ -40,10 +44,36 @@ func routes(_ app: Application) throws {
|
|||||||
try server.save(image: data, for: cap)
|
try server.save(image: data, for: cap)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the classifier versions from file on disk
|
// Update the classifier
|
||||||
app.getCatching("refresh") { _ in
|
app.postCatching("classifier", ":version") { request in
|
||||||
let url = URL(fileURLWithPath: app.directory.resourcesDirectory)
|
try authorize(request)
|
||||||
.appendingPathComponent("classifier.trained")
|
guard let version = request.parameters.get("version", as: Int.self) else {
|
||||||
server.updateClassifierCaps(from: url)
|
log("Invalid parameter for version")
|
||||||
|
throw Abort(.badRequest)
|
||||||
|
}
|
||||||
|
guard version > server.classifierVersion else {
|
||||||
|
throw Abort(.alreadyReported)
|
||||||
|
}
|
||||||
|
guard let buffer = request.body.data else {
|
||||||
|
log("Missing body data: \(request.body.description)")
|
||||||
|
throw CapError.invalidBody
|
||||||
|
}
|
||||||
|
let data = Data(buffer: buffer)
|
||||||
|
try server.save(image: data, for: version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the trained classes
|
||||||
|
app.postCatching("classes") { request in
|
||||||
|
try authorize(request)
|
||||||
|
guard let buffer = request.body.data else {
|
||||||
|
log("Missing body data: \(request.body.description)")
|
||||||
|
throw CapError.invalidBody
|
||||||
|
}
|
||||||
|
let data = Data(buffer: buffer)
|
||||||
|
guard let content = String(data: data, encoding: .utf8) else {
|
||||||
|
log("Invalid string body: \(request.body.description)")
|
||||||
|
throw CapError.invalidBody
|
||||||
|
}
|
||||||
|
server.updateTrainedClasses(content: content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user