Compare commits

...

3 Commits

Author SHA1 Message Date
Christoph Hagen
dec1154006 Exclude caps with too few images 2024-10-28 19:06:45 +01:00
Christoph Hagen
93af0cb76d Don't cache web requests 2024-06-12 16:14:05 +02:00
Christoph Hagen
58b1dcb277 Allow skipping of training 2024-06-12 16:13:42 +02:00
6 changed files with 67 additions and 13 deletions

View File

@@ -3,4 +3,5 @@
"iterations": 20,
"server": "https://mydomain.com/caps",
"authentication": "mysecretkey",
"minimumImagesPerCap": 10
}

View File

@@ -6,6 +6,8 @@ struct CapTrain: AsyncParsableCommand {
private static let defaultIterations = 10
private static let defaultMinimumImageCount = 10
@Flag(name: .shortAndLong, help: "Resume the previous training session (default: false)")
var resume: Bool = false
@@ -15,6 +17,9 @@ struct CapTrain: AsyncParsableCommand {
@Option(name: .shortAndLong, help: "The number of iterations to train (default: 10)")
var iterations: Int?
@Option(name: .shortAndLong, help: "The minimum number of images for a cap to be included in training (default: 10)")
var minimumImageCount: Int?
@Option(name: .shortAndLong, help: "The url of the caps server to retrieve images and upload the classifier")
var server: String?
@@ -23,10 +28,14 @@ struct CapTrain: AsyncParsableCommand {
@Option(name: .shortAndLong, help: "The path to the folder where the content (images, classifier, thumbnails) is stored")
var folder: String?
@Flag(name: .long, help: "Only upload the previous classifier")
var skipTraining: Bool = false
func run() async throws {
let configurationFile = try configurationFile()
let iterations = iterations ?? configurationFile?.iterations ?? CapTrain.defaultIterations
let minimumCount = minimumImageCount ?? configurationFile?.minimumImagesPerCap ?? CapTrain.defaultMinimumImageCount
guard let contentFolder = folder ?? configurationFile?.folder else {
throw TrainingError.missingArguments("folder")
}
@@ -36,15 +45,15 @@ struct CapTrain: AsyncParsableCommand {
guard let authentication = authentication ?? configurationFile?.authentication else {
throw TrainingError.missingArguments("authentication")
}
let configuration = Configuration(
contentFolder: contentFolder,
trainingIterations: iterations,
serverPath: serverPath,
authenticationToken: authentication)
authenticationToken: authentication,
minimumImagesPerCap: minimumCount)
let creator = try ClassifierCreator(configuration: configuration, resume: resume)
try await creator.run()
try await creator.run(skipTraining: skipTraining)
}
private func configurationFile() throws -> ConfigurationFile? {

View File

@@ -15,6 +15,9 @@ final class ClassifierCreator {
let configuration: Configuration
/// The number of images required to include a cap in training
let minimumImagesPerCap: Int
let imageDirectory: URL
let thumbnailDirectory: URL
@@ -23,6 +26,8 @@ final class ClassifierCreator {
let classifierUrl: URL
let urlSession = URLSession(configuration: .ephemeral)
let df = DateFormatter()
private func print(info: String) {
@@ -34,6 +39,7 @@ final class ClassifierCreator {
init(configuration: Configuration, resume: Bool) throws {
self.configuration = configuration
self.server = try configuration.serverUrl()
self.minimumImagesPerCap = configuration.minimumImagesPerCap
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
self.imageDirectory = contentDirectory.appendingPathComponent("images")
self.sessionDirectory = contentDirectory.appendingPathComponent("session")
@@ -59,7 +65,7 @@ final class ClassifierCreator {
// MARK: Main function
func run() async throws {
func run(skipTraining: Bool) async throws {
let imagesSnapshotDate = Date()
let (classes, changedImageCount, changedMainImages) = try await loadImages()
@@ -78,13 +84,16 @@ final class ClassifierCreator {
let classifierVersion = try await getClassifierVersion()
let newVersion = classifierVersion + 1
print(info: "Server: \(server.path)")
print(info: "Image directory: \(imageDirectory.absoluteURL.path)")
print(info: "Model path: \(classifierUrl.path)")
print(info: "Version: \(newVersion)")
print(info: "Classes: \(classes.count)")
print(info: "Iterations: \(configuration.trainingIterations)")
try await trainAndSaveModel()
if !skipTraining {
try await trainAndSaveModel()
}
try await uploadModel(version: newVersion)
try await upload(classes: classes, lastUpdate: imagesSnapshotDate)
try await createThumbnails(changed: changedMainImages)
@@ -101,7 +110,16 @@ final class ClassifierCreator {
throw TrainingError.mainImageFolderNotCreated(error)
}
let imageCounts = try await getImageCounts()
.filter { id, count in
// Delete caps with small counts to ensure proper training
if count < self.minimumImagesPerCap {
print(info: "Excluding cap \(id) from training (\(count) images)")
return false
}
return true
}
let missingImageList: [CapImage] = imageCounts
.filter { $0.value >= self.minimumImagesPerCap }
.sorted { $0.key < $1.key }
.reduce(into: []) { list, pair in
let missingImagesForCap: [CapImage] = (0..<pair.value).compactMap { image in
@@ -118,7 +136,7 @@ final class ClassifierCreator {
print(info: "No missing images to load")
} else {
print(info: "Loading \(missingImageList.count) missing images...")
try await loadImages(missingImageList)
try await loadImagesInBatches(missingImageList)
}
@@ -133,7 +151,7 @@ final class ClassifierCreator {
print(info: "No changed images to load" + suffix)
} else {
print(info: "Loading \(filteredChangeList.count) changed images" + suffix)
try await loadImages(filteredChangeList)
try await loadImagesInBatches(filteredChangeList)
}
let changedMainImages = changedImageList.filter { $0.image == 0 }.map { $0.cap }
@@ -178,15 +196,20 @@ final class ClassifierCreator {
let url = imageDirectory.appendingPathComponent(folder)
do {
try FileManager.default.removeItem(at: url)
print(info: "Removed unused image folder '\(folder)'")
print(info: "Removed item in image folder: '\(folder)'")
} catch {
throw TrainingError.failedToRemoveImageFolder(folder, error)
}
}
}
private func capFolderUrl(base: URL, cap: Int) -> URL {
base.appendingPathComponent(String(format: "%04d", cap))
}
private func imageUrl(base: URL, image: CapImage) -> URL {
base.appendingPathComponent(String(format: "%04d/%04d-%02d.jpg", image.cap, image.cap, image.image))
capFolderUrl(base: base, cap: image.cap)
.appendingPathComponent(String(format: "%04d-%02d.jpg", image.cap, image.image))
}
private func load(image: CapImage) async throws {
@@ -198,7 +221,7 @@ final class ClassifierCreator {
let url = imageUrl(base: server.appendingPathComponent("images"), image: image)
let tempFile: URL, response: URLResponse
do {
(tempFile, response) = try await URLSession.shared.download(from: url)
(tempFile, response) = try await urlSession.download(from: url)
} catch {
throw TrainingError.failedToLoadImage(image, error)
}
@@ -217,7 +240,23 @@ final class ClassifierCreator {
}
}
private func loadImages(_ list: [CapImage]) async throws {
private func loadImagesInBatches(_ list: [CapImage], batchSize: Int = 100) async throws {
guard !list.isEmpty else {
return
}
print(info: "Loading \(list.count) images...")
var startIndex = list.startIndex
while startIndex < list.endIndex {
let endIndex = min(startIndex + batchSize, list.count)
let batch = Array(list[startIndex..<endIndex])
print(info: "Completed \(startIndex - list.startIndex) / \(list.count) images")
try await loadImageBatch(batch)
startIndex = endIndex
}
print(info: "Completed \(startIndex - list.startIndex) / \(list.count) images")
}
private func loadImageBatch(_ list: [CapImage]) async throws {
var loadedImageCount = 0
var errors = [Error]()
await withTaskGroup(of: Error?.self) { group in
@@ -520,7 +559,7 @@ final class ClassifierCreator {
}
private func perform(_ request: URLRequest) async throws -> Data {
let (data, response) = try await URLSession.shared.data(for: request)
let (data, response) = try await urlSession.data(for: request)
let code = (response as! HTTPURLResponse).statusCode
guard code == 200 else {
throw TrainingError.invalidResponse(request.url!, code)

View File

@@ -9,6 +9,8 @@ struct Configuration {
let serverPath: String
let authenticationToken: String
let minimumImagesPerCap: Int
}
extension Configuration {

View File

@@ -9,6 +9,9 @@ struct ConfigurationFile {
let server: String?
let authentication: String?
/// The number of images required to include a cap in training
let minimumImagesPerCap: Int?
}
extension ConfigurationFile: Decodable {

View File

@@ -87,8 +87,8 @@ extension TrainingError: CustomStringConvertible {
case .failedToSaveImage(let image, let error):
return "Failed to save \(image): \(error)"
case .failedToLoadImages(let expected, let loaded):
return "Only \(expected) of \(loaded) images loaded"
return "Only \(loaded) of \(expected) images loaded"
case .failedToGetChangedImagesList(let error):
return "Failed to get list of changed images: \(error)"
case .invalidEntryInChangeList(let entry):