diff --git a/Config/config_example.json b/Config/config_example.json index fdce908..8c71bca 100644 --- a/Config/config_example.json +++ b/Config/config_example.json @@ -3,4 +3,5 @@ "iterations": 20, "server": "https://mydomain.com/caps", "authentication": "mysecretkey", + "minimumImagesPerCap": 10 } diff --git a/Sources/CapTrain.swift b/Sources/CapTrain.swift index cb00b59..81c8174 100644 --- a/Sources/CapTrain.swift +++ b/Sources/CapTrain.swift @@ -6,6 +6,8 @@ struct CapTrain: AsyncParsableCommand { private static let defaultIterations = 10 + private static let defaultMinimumImageCount = 10 + @Flag(name: .shortAndLong, help: "Resume the previous training session (default: false)") var resume: Bool = false @@ -15,6 +17,9 @@ struct CapTrain: AsyncParsableCommand { @Option(name: .shortAndLong, help: "The number of iterations to train (default: 10)") var iterations: Int? + @Option(name: .shortAndLong, help: "The minimum number of images for a cap to be included in training (default: 10)") + var minimumImageCount: Int? + @Option(name: .shortAndLong, help: "The url of the caps server to retrieve images and upload the classifier") var server: String? @@ -30,6 +35,7 @@ struct CapTrain: AsyncParsableCommand { func run() async throws { let configurationFile = try configurationFile() let iterations = iterations ?? configurationFile?.iterations ?? CapTrain.defaultIterations + let minimumCount = minimumImageCount ?? configurationFile?.minimumImagesPerCap ?? CapTrain.defaultMinimumImageCount guard let contentFolder = folder ?? configurationFile?.folder else { throw TrainingError.missingArguments("folder") } @@ -43,8 +49,9 @@ struct CapTrain: AsyncParsableCommand { contentFolder: contentFolder, trainingIterations: iterations, serverPath: serverPath, - authenticationToken: authentication) + authenticationToken: authentication, + minimumImagesPerCap: minimumCount) let creator = try ClassifierCreator(configuration: configuration, resume: resume) try await creator.run(skipTraining: skipTraining) } diff --git a/Sources/ClassifierCreator.swift b/Sources/ClassifierCreator.swift index a40fcc8..2f30e38 100644 --- a/Sources/ClassifierCreator.swift +++ b/Sources/ClassifierCreator.swift @@ -15,6 +15,9 @@ final class ClassifierCreator { let configuration: Configuration + /// The number of images required to include a cap in training + let minimumImagesPerCap: Int + let imageDirectory: URL let thumbnailDirectory: URL @@ -36,6 +39,7 @@ final class ClassifierCreator { init(configuration: Configuration, resume: Bool) throws { self.configuration = configuration self.server = try configuration.serverUrl() + self.minimumImagesPerCap = configuration.minimumImagesPerCap let contentDirectory = URL(fileURLWithPath: configuration.contentFolder) self.imageDirectory = contentDirectory.appendingPathComponent("images") self.sessionDirectory = contentDirectory.appendingPathComponent("session") @@ -106,7 +110,16 @@ final class ClassifierCreator { throw TrainingError.mainImageFolderNotCreated(error) } let imageCounts = try await getImageCounts() + .filter { id, count in + // Delete caps with small counts to ensure proper training + if count < self.minimumImagesPerCap { + print(info: "Excluding cap \(id) from training (\(count) images)") + return false + } + return true + } let missingImageList: [CapImage] = imageCounts + .filter { $0.value >= self.minimumImagesPerCap } .sorted { $0.key < $1.key } .reduce(into: []) { list, pair in let missingImagesForCap: [CapImage] = (0.. URL { + base.appendingPathComponent(String(format: "%04d", cap)) + } + private func imageUrl(base: URL, image: CapImage) -> URL { - base.appendingPathComponent(String(format: "%04d/%04d-%02d.jpg", image.cap, image.cap, image.image)) + capFolderUrl(base: base, cap: image.cap) + .appendingPathComponent(String(format: "%04d-%02d.jpg", image.cap, image.image)) } private func load(image: CapImage) async throws { @@ -222,7 +240,23 @@ final class ClassifierCreator { } } - private func loadImages(_ list: [CapImage]) async throws { + private func loadImagesInBatches(_ list: [CapImage], batchSize: Int = 100) async throws { + guard !list.isEmpty else { + return + } + print(info: "Loading \(list.count) images...") + var startIndex = list.startIndex + while startIndex < list.endIndex { + let endIndex = min(startIndex + batchSize, list.count) + let batch = Array(list[startIndex..