Exclude caps with too few images
This commit is contained in:
parent
93af0cb76d
commit
dec1154006
@ -3,4 +3,5 @@
|
|||||||
"iterations": 20,
|
"iterations": 20,
|
||||||
"server": "https://mydomain.com/caps",
|
"server": "https://mydomain.com/caps",
|
||||||
"authentication": "mysecretkey",
|
"authentication": "mysecretkey",
|
||||||
|
"minimumImagesPerCap": 10
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,8 @@ struct CapTrain: AsyncParsableCommand {
|
|||||||
|
|
||||||
private static let defaultIterations = 10
|
private static let defaultIterations = 10
|
||||||
|
|
||||||
|
private static let defaultMinimumImageCount = 10
|
||||||
|
|
||||||
@Flag(name: .shortAndLong, help: "Resume the previous training session (default: false)")
|
@Flag(name: .shortAndLong, help: "Resume the previous training session (default: false)")
|
||||||
var resume: Bool = false
|
var resume: Bool = false
|
||||||
|
|
||||||
@ -15,6 +17,9 @@ struct CapTrain: AsyncParsableCommand {
|
|||||||
@Option(name: .shortAndLong, help: "The number of iterations to train (default: 10)")
|
@Option(name: .shortAndLong, help: "The number of iterations to train (default: 10)")
|
||||||
var iterations: Int?
|
var iterations: Int?
|
||||||
|
|
||||||
|
@Option(name: .shortAndLong, help: "The minimum number of images for a cap to be included in training (default: 10)")
|
||||||
|
var minimumImageCount: Int?
|
||||||
|
|
||||||
@Option(name: .shortAndLong, help: "The url of the caps server to retrieve images and upload the classifier")
|
@Option(name: .shortAndLong, help: "The url of the caps server to retrieve images and upload the classifier")
|
||||||
var server: String?
|
var server: String?
|
||||||
|
|
||||||
@ -30,6 +35,7 @@ struct CapTrain: AsyncParsableCommand {
|
|||||||
func run() async throws {
|
func run() async throws {
|
||||||
let configurationFile = try configurationFile()
|
let configurationFile = try configurationFile()
|
||||||
let iterations = iterations ?? configurationFile?.iterations ?? CapTrain.defaultIterations
|
let iterations = iterations ?? configurationFile?.iterations ?? CapTrain.defaultIterations
|
||||||
|
let minimumCount = minimumImageCount ?? configurationFile?.minimumImagesPerCap ?? CapTrain.defaultMinimumImageCount
|
||||||
guard let contentFolder = folder ?? configurationFile?.folder else {
|
guard let contentFolder = folder ?? configurationFile?.folder else {
|
||||||
throw TrainingError.missingArguments("folder")
|
throw TrainingError.missingArguments("folder")
|
||||||
}
|
}
|
||||||
@ -43,8 +49,9 @@ struct CapTrain: AsyncParsableCommand {
|
|||||||
contentFolder: contentFolder,
|
contentFolder: contentFolder,
|
||||||
trainingIterations: iterations,
|
trainingIterations: iterations,
|
||||||
serverPath: serverPath,
|
serverPath: serverPath,
|
||||||
authenticationToken: authentication)
|
|
||||||
|
|
||||||
|
authenticationToken: authentication,
|
||||||
|
minimumImagesPerCap: minimumCount)
|
||||||
let creator = try ClassifierCreator(configuration: configuration, resume: resume)
|
let creator = try ClassifierCreator(configuration: configuration, resume: resume)
|
||||||
try await creator.run(skipTraining: skipTraining)
|
try await creator.run(skipTraining: skipTraining)
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,9 @@ final class ClassifierCreator {
|
|||||||
|
|
||||||
let configuration: Configuration
|
let configuration: Configuration
|
||||||
|
|
||||||
|
/// The number of images required to include a cap in training
|
||||||
|
let minimumImagesPerCap: Int
|
||||||
|
|
||||||
let imageDirectory: URL
|
let imageDirectory: URL
|
||||||
|
|
||||||
let thumbnailDirectory: URL
|
let thumbnailDirectory: URL
|
||||||
@ -36,6 +39,7 @@ final class ClassifierCreator {
|
|||||||
init(configuration: Configuration, resume: Bool) throws {
|
init(configuration: Configuration, resume: Bool) throws {
|
||||||
self.configuration = configuration
|
self.configuration = configuration
|
||||||
self.server = try configuration.serverUrl()
|
self.server = try configuration.serverUrl()
|
||||||
|
self.minimumImagesPerCap = configuration.minimumImagesPerCap
|
||||||
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
|
let contentDirectory = URL(fileURLWithPath: configuration.contentFolder)
|
||||||
self.imageDirectory = contentDirectory.appendingPathComponent("images")
|
self.imageDirectory = contentDirectory.appendingPathComponent("images")
|
||||||
self.sessionDirectory = contentDirectory.appendingPathComponent("session")
|
self.sessionDirectory = contentDirectory.appendingPathComponent("session")
|
||||||
@ -106,7 +110,16 @@ final class ClassifierCreator {
|
|||||||
throw TrainingError.mainImageFolderNotCreated(error)
|
throw TrainingError.mainImageFolderNotCreated(error)
|
||||||
}
|
}
|
||||||
let imageCounts = try await getImageCounts()
|
let imageCounts = try await getImageCounts()
|
||||||
|
.filter { id, count in
|
||||||
|
// Delete caps with small counts to ensure proper training
|
||||||
|
if count < self.minimumImagesPerCap {
|
||||||
|
print(info: "Excluding cap \(id) from training (\(count) images)")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
let missingImageList: [CapImage] = imageCounts
|
let missingImageList: [CapImage] = imageCounts
|
||||||
|
.filter { $0.value >= self.minimumImagesPerCap }
|
||||||
.sorted { $0.key < $1.key }
|
.sorted { $0.key < $1.key }
|
||||||
.reduce(into: []) { list, pair in
|
.reduce(into: []) { list, pair in
|
||||||
let missingImagesForCap: [CapImage] = (0..<pair.value).compactMap { image in
|
let missingImagesForCap: [CapImage] = (0..<pair.value).compactMap { image in
|
||||||
@ -123,7 +136,7 @@ final class ClassifierCreator {
|
|||||||
print(info: "No missing images to load")
|
print(info: "No missing images to load")
|
||||||
} else {
|
} else {
|
||||||
print(info: "Loading \(missingImageList.count) missing images...")
|
print(info: "Loading \(missingImageList.count) missing images...")
|
||||||
try await loadImages(missingImageList)
|
try await loadImagesInBatches(missingImageList)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -138,7 +151,7 @@ final class ClassifierCreator {
|
|||||||
print(info: "No changed images to load" + suffix)
|
print(info: "No changed images to load" + suffix)
|
||||||
} else {
|
} else {
|
||||||
print(info: "Loading \(filteredChangeList.count) changed images" + suffix)
|
print(info: "Loading \(filteredChangeList.count) changed images" + suffix)
|
||||||
try await loadImages(filteredChangeList)
|
try await loadImagesInBatches(filteredChangeList)
|
||||||
}
|
}
|
||||||
|
|
||||||
let changedMainImages = changedImageList.filter { $0.image == 0 }.map { $0.cap }
|
let changedMainImages = changedImageList.filter { $0.image == 0 }.map { $0.cap }
|
||||||
@ -183,15 +196,20 @@ final class ClassifierCreator {
|
|||||||
let url = imageDirectory.appendingPathComponent(folder)
|
let url = imageDirectory.appendingPathComponent(folder)
|
||||||
do {
|
do {
|
||||||
try FileManager.default.removeItem(at: url)
|
try FileManager.default.removeItem(at: url)
|
||||||
print(info: "Removed unused image folder '\(folder)'")
|
print(info: "Removed item in image folder: '\(folder)'")
|
||||||
} catch {
|
} catch {
|
||||||
throw TrainingError.failedToRemoveImageFolder(folder, error)
|
throw TrainingError.failedToRemoveImageFolder(folder, error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private func capFolderUrl(base: URL, cap: Int) -> URL {
|
||||||
|
base.appendingPathComponent(String(format: "%04d", cap))
|
||||||
|
}
|
||||||
|
|
||||||
private func imageUrl(base: URL, image: CapImage) -> URL {
|
private func imageUrl(base: URL, image: CapImage) -> URL {
|
||||||
base.appendingPathComponent(String(format: "%04d/%04d-%02d.jpg", image.cap, image.cap, image.image))
|
capFolderUrl(base: base, cap: image.cap)
|
||||||
|
.appendingPathComponent(String(format: "%04d-%02d.jpg", image.cap, image.image))
|
||||||
}
|
}
|
||||||
|
|
||||||
private func load(image: CapImage) async throws {
|
private func load(image: CapImage) async throws {
|
||||||
@ -222,7 +240,23 @@ final class ClassifierCreator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private func loadImages(_ list: [CapImage]) async throws {
|
private func loadImagesInBatches(_ list: [CapImage], batchSize: Int = 100) async throws {
|
||||||
|
guard !list.isEmpty else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
print(info: "Loading \(list.count) images...")
|
||||||
|
var startIndex = list.startIndex
|
||||||
|
while startIndex < list.endIndex {
|
||||||
|
let endIndex = min(startIndex + batchSize, list.count)
|
||||||
|
let batch = Array(list[startIndex..<endIndex])
|
||||||
|
print(info: "Completed \(startIndex - list.startIndex) / \(list.count) images")
|
||||||
|
try await loadImageBatch(batch)
|
||||||
|
startIndex = endIndex
|
||||||
|
}
|
||||||
|
print(info: "Completed \(startIndex - list.startIndex) / \(list.count) images")
|
||||||
|
}
|
||||||
|
|
||||||
|
private func loadImageBatch(_ list: [CapImage]) async throws {
|
||||||
var loadedImageCount = 0
|
var loadedImageCount = 0
|
||||||
var errors = [Error]()
|
var errors = [Error]()
|
||||||
await withTaskGroup(of: Error?.self) { group in
|
await withTaskGroup(of: Error?.self) { group in
|
||||||
|
@ -9,6 +9,8 @@ struct Configuration {
|
|||||||
let serverPath: String
|
let serverPath: String
|
||||||
|
|
||||||
let authenticationToken: String
|
let authenticationToken: String
|
||||||
|
|
||||||
|
let minimumImagesPerCap: Int
|
||||||
}
|
}
|
||||||
|
|
||||||
extension Configuration {
|
extension Configuration {
|
||||||
|
@ -9,6 +9,9 @@ struct ConfigurationFile {
|
|||||||
let server: String?
|
let server: String?
|
||||||
|
|
||||||
let authentication: String?
|
let authentication: String?
|
||||||
|
|
||||||
|
/// The number of images required to include a cap in training
|
||||||
|
let minimumImagesPerCap: Int?
|
||||||
}
|
}
|
||||||
|
|
||||||
extension ConfigurationFile: Decodable {
|
extension ConfigurationFile: Decodable {
|
||||||
|
@ -87,8 +87,8 @@ extension TrainingError: CustomStringConvertible {
|
|||||||
case .failedToSaveImage(let image, let error):
|
case .failedToSaveImage(let image, let error):
|
||||||
return "Failed to save \(image): \(error)"
|
return "Failed to save \(image): \(error)"
|
||||||
case .failedToLoadImages(let expected, let loaded):
|
case .failedToLoadImages(let expected, let loaded):
|
||||||
return "Only \(expected) of \(loaded) images loaded"
|
|
||||||
|
|
||||||
|
return "Only \(loaded) of \(expected) images loaded"
|
||||||
case .failedToGetChangedImagesList(let error):
|
case .failedToGetChangedImagesList(let error):
|
||||||
return "Failed to get list of changed images: \(error)"
|
return "Failed to get list of changed images: \(error)"
|
||||||
case .invalidEntryInChangeList(let entry):
|
case .invalidEntryInChangeList(let entry):
|
||||||
|
Loading…
Reference in New Issue
Block a user