Extract session parameters

This commit is contained in:
Christoph Hagen 2023-10-24 10:38:24 +02:00
parent 2913878b3e
commit a328c3a59e

View File

@ -4,6 +4,11 @@ import Combine
final class ClassifierCreator { final class ClassifierCreator {
private let scenePrintRevision: Int? = 2
private let sessionProgressIterations = 10
private let sessionCheckpointIterations = 100
private let sessionIterations = 1000
let server: URL let server: URL
let configuration: Configuration let configuration: Configuration
@ -276,13 +281,15 @@ final class ClassifierCreator {
func trainModelAsync() async throws -> MLImageClassifier { func trainModelAsync() async throws -> MLImageClassifier {
let params = MLImageClassifier.ModelParameters( let params = MLImageClassifier.ModelParameters(
maxIterations: configuration.trainingIterations, maxIterations: configuration.trainingIterations,
augmentation: []) augmentation: [],
algorithm: .transferLearning(
featureExtractor: .scenePrint(revision: scenePrintRevision),
classifier: .logisticRegressor))
let sessionParameters = MLTrainingSessionParameters( let sessionParameters = MLTrainingSessionParameters(
sessionDirectory: sessionDirectory, sessionDirectory: sessionDirectory,
reportInterval: 10, reportInterval: sessionProgressIterations,
checkpointInterval: 100, checkpointInterval: sessionCheckpointIterations,
iterations: 1000) iterations: sessionIterations)
var subscriptions = [AnyCancellable]() var subscriptions = [AnyCancellable]()