Improve progress updating, fix warnings

This commit is contained in:
Christoph Hagen
2025-01-31 14:52:20 +01:00
parent 3dc5674a3a
commit 92e1eacf57
9 changed files with 154 additions and 144 deletions

View File

@@ -3,6 +3,7 @@ import SwiftUI
import Vision
import CryptoKit
@MainActor
final class Database: ObservableObject {
@AppStorage("classifier")
@@ -33,6 +34,8 @@ final class Database: ObservableObject {
private let encoder = JSONEncoder()
private let decoder = JSONDecoder()
private let progressObserver = ClassifierProgressObserver()
@AppStorage("serverUrl")
var serverPath: String = "" {
didSet {
@@ -150,9 +153,7 @@ final class Database: ObservableObject {
}
set {
_classifierClassesCache = newValue
DispatchQueue.main.async {
self._classifierClassesString = newValue.map { "\($0)" }.joined(separator: ",")
}
self._classifierClassesString = newValue.map { "\($0)" }.joined(separator: ",")
}
}
@@ -187,6 +188,11 @@ final class Database: ObservableObject {
updatePendingImageUploadCount(imageUploads: imageUploads)
}
convenience init(caps: [Int : Cap]) {
self.init()
self.caps = caps
}
func mainImage(for cap: Int) -> Int {
caps[cap]?.mainImage ?? 0
}
@@ -358,9 +364,7 @@ final class Database: ObservableObject {
} else {
oldCap.update(with: cap)
let save = oldCap
DispatchQueue.main.async {
self.caps[cap.id] = save
}
self.caps[cap.id] = save
updates += 1
}
}
@@ -394,9 +398,7 @@ final class Database: ObservableObject {
log("Classifier version has an invalid value '\(string)'")
return false
}
DispatchQueue.main.async {
self.storedServerClassifierVersion = serverVersion
}
self.storedServerClassifierVersion = serverVersion
return true
}
@@ -409,19 +411,16 @@ final class Database: ObservableObject {
log("Downloading classifier")
let progress = ClassifierProgress()
DispatchQueue.main.async {
self.classifierDownloadProgress = progress
}
self.classifierDownloadProgress = progress
defer {
DispatchQueue.main.async {
self.classifierDownloadProgress = nil
}
self.classifierDownloadProgress = nil
}
let tempUrl: URL
let response: URLResponse
do {
(tempUrl, response) = try await URLSession.shared.download(from: serverClassifierUrl, delegate: progress)
progressObserver.delegate = self
(tempUrl, response) = try await URLSession.shared.download(from: serverClassifierUrl, delegate: progressObserver)
} catch {
log("Failed to download classifier version: \(error)")
return false
@@ -439,14 +438,14 @@ final class Database: ObservableObject {
log("Failed to replace classifier: \(error)")
return false
}
DispatchQueue.main.async {
self.storedLocalClassifierVersion = self.serverClassifierVersion
log("Downloaded classifier \(self.localClassifierVersion)")
self.classifier = nil
}
self.storedLocalClassifierVersion = self.serverClassifierVersion
log("Downloaded classifier \(self.localClassifierVersion)")
self.classifier = nil
return true
}
@discardableResult
func downloadClassifierClasses() async -> Bool {
guard let serverClassifierClassesUrl else {
@@ -509,9 +508,7 @@ final class Database: ObservableObject {
func save(newCap name: String) -> Cap {
let cap = Cap(id: nextCapId, name: name)
caps[cap.id] = cap
DispatchQueue.main.async {
self.changedCaps.insert(cap.id)
}
self.changedCaps.insert(cap.id)
return cap
}
@@ -526,17 +523,11 @@ final class Database: ObservableObject {
}
log("Saved image \(cap.imageCount) for cap \(capId)")
if imageUploads[capId] != nil {
DispatchQueue.main.async {
self.imageUploads[capId]!.append(cap.imageCount)
}
self.imageUploads[capId]!.append(cap.imageCount)
} else {
DispatchQueue.main.async {
self.imageUploads[capId] = [cap.imageCount]
}
}
DispatchQueue.main.async {
self.caps[capId]!.imageCount += 1
self.imageUploads[capId] = [cap.imageCount]
}
self.caps[capId]!.imageCount += 1
return true
}
@@ -558,14 +549,10 @@ final class Database: ObservableObject {
return
}
log("Starting upload timer")
DispatchQueue.main.async {
self.uploadTimer = Timer.scheduledTimer(withTimeInterval: 5, repeats: true, block: self.uploadTimerElapsed)
}
}
private func uploadTimerElapsed(timer: Timer) {
Task {
await uploadAll()
self.uploadTimer = Timer.scheduledTimer(withTimeInterval: 5, repeats: true) { _ in
Task {
await self.uploadAll()
}
}
}
@@ -574,22 +561,16 @@ final class Database: ObservableObject {
log("Already uploading")
return
}
DispatchQueue.main.async {
self.isUploading = true
}
self.isUploading = true
defer {
DispatchQueue.main.async {
self.isUploading = false
}
self.isUploading = false
}
guard !changedCaps.isEmpty || pendingImageUploadCount > 0 else {
return
}
log("Starting uploads")
let uploaded = await uploadAllChangedCaps()
DispatchQueue.main.async {
self.changedCaps.subtract(uploaded)
}
self.changedCaps.subtract(uploaded)
await uploadAllImages()
log("Uploads finished")
}
@@ -622,14 +603,10 @@ final class Database: ObservableObject {
}
log("Uploaded image \(image) for cap \(cap)")
let remaining = imageUploads[cap]?.filter { $0 != image }
if let r = remaining, !r.isEmpty {
DispatchQueue.main.async {
self.imageUploads[cap] = r
}
if let remaining, !remaining.isEmpty {
self.imageUploads[cap] = remaining
} else {
DispatchQueue.main.async {
self.imageUploads[cap] = nil
}
self.imageUploads[cap] = nil
}
}
}
@@ -664,9 +641,7 @@ final class Database: ObservableObject {
if httpResponse.statusCode == 410 {
log("Missing cap for image \(url.lastPathComponent), reupload cap")
// Missing cap, force upload
DispatchQueue.main.async {
self.changedCaps.insert(cap)
}
self.changedCaps.insert(cap)
} else {
log("Failed to upload image \(url.lastPathComponent): Response \(httpResponse.statusCode)")
}
@@ -736,9 +711,7 @@ final class Database: ObservableObject {
log("Failed to upload cap \(cap.id): Response \(httpResponse.statusCode)")
return false
}
DispatchQueue.main.async {
self.changedCaps.remove(cap.id)
}
self.changedCaps.remove(cap.id)
return true
} catch {
log("Failed to upload cap \(cap.id): \(error)")
@@ -761,10 +734,8 @@ final class Database: ObservableObject {
}
cap.mainImage = version
let finalCap = cap
DispatchQueue.main.async {
self.caps[capId] = finalCap
log("Set main image \(version) for \(capId)")
}
self.caps[capId] = finalCap
log("Set main image \(version) for \(capId)")
return finalCap
}
@@ -852,9 +823,7 @@ final class Database: ObservableObject {
// Delete cached images
images.removeCachedImages(for: cap)
// Delete cap
DispatchQueue.main.async {
self.caps[cap] = nil
}
self.caps[cap] = nil
log("Deleted cap \(cap)")
return true
} catch {
@@ -888,19 +857,24 @@ final class Database: ObservableObject {
log("Image removed")
return
}
DispatchQueue.global().async {
Task {
guard let classifier = self.getClassifier() else {
return
}
log("Image classification started")
classifier.recognize(image: image) { matches in
DispatchQueue.main.async {
self.matches = matches ?? [:]
let matches = await withCheckedContinuation { continuation in
classifier.recognize(image: image) { matches in
continuation.resume(returning: matches)
}
}
self.update(matches: matches ?? [:])
}
}
func update(matches: [Int : Float]) {
self.matches = matches
}
func canClassify(cap id: Int) -> Bool {
classifierClasses.contains(id)
}
@@ -1027,68 +1001,13 @@ final class Database: ObservableObject {
}
}
extension Database {
final class ClassifierProgress: NSObject, ObservableObject {
@Published
var bytesLoaded: Double
@Published
var total: Double
var percentage: Double {
guard total > 0 else {
return 0.0
extension Database: ClassifierProgressDelegate {
nonisolated func classifierProgress(_ progress: ClassifierProgress) {
Task {
await MainActor.run {
self.classifierDownloadProgress = progress
}
return bytesLoaded * 100 / total
}
init(bytesLoaded: Double = 0, total: Double = 0) {
self.bytesLoaded = bytesLoaded
self.total = total
}
}
}
extension Database.ClassifierProgress: URLSessionDownloadDelegate {
func urlSession(_ session: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
}
func urlSession(_ session: URLSession, downloadTask: URLSessionDownloadTask, didWriteData bytesWritten: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) {
DispatchQueue.main.async {
self.bytesLoaded = Double(totalBytesWritten)
self.total = Double(totalBytesExpectedToWrite)
}
}
}
extension Database {
static var mock: Database {
let db = Database()
db.serverPath = "https://caps.christophhagen.de"
db.caps = [
Cap(id: 123, name: "My new cap"),
Cap(id: 234, name: "My favorite cap"),
Cap(id: 345, name: "My oldest cap"),
Cap(id: 456, name: "My new cap"),
Cap(id: 567, name: "My favorite cap"),
Cap(id: 678, name: "My oldest cap"),
].reduce(into: [:]) { $0[$1.id] = $1 }
db.image = UIImage(systemSymbol: .photo)
return db
}
static var largeMock: Database {
let db = Database()
db.serverPath = "https://caps.christophhagen.de"
db.caps = (1..<500)
.map { Cap(id: $0, name: "Cap \($0)") }
.reduce(into: [:]) { $0[$1.id] = $1 }
db.image = UIImage(systemSymbol: .photo)
return db
}
}