diff --git a/Sources/ClassifierCreator.swift b/Sources/ClassifierCreator.swift index b13f4ab..e95f907 100644 --- a/Sources/ClassifierCreator.swift +++ b/Sources/ClassifierCreator.swift @@ -4,6 +4,11 @@ import Combine final class ClassifierCreator { + private let scenePrintRevision: Int? = 2 + private let sessionProgressIterations = 10 + private let sessionCheckpointIterations = 100 + private let sessionIterations = 1000 + let server: URL let configuration: Configuration @@ -276,13 +281,15 @@ final class ClassifierCreator { func trainModelAsync() async throws -> MLImageClassifier { let params = MLImageClassifier.ModelParameters( maxIterations: configuration.trainingIterations, - augmentation: []) - + augmentation: [], + algorithm: .transferLearning( + featureExtractor: .scenePrint(revision: scenePrintRevision), + classifier: .logisticRegressor)) let sessionParameters = MLTrainingSessionParameters( sessionDirectory: sessionDirectory, - reportInterval: 10, - checkpointInterval: 100, - iterations: 1000) + reportInterval: sessionProgressIterations, + checkpointInterval: sessionCheckpointIterations, + iterations: sessionIterations) var subscriptions = [AnyCancellable]()