Extract session parameters
This commit is contained in:
parent
2913878b3e
commit
a328c3a59e
@ -4,6 +4,11 @@ import Combine
|
|||||||
|
|
||||||
final class ClassifierCreator {
|
final class ClassifierCreator {
|
||||||
|
|
||||||
|
private let scenePrintRevision: Int? = 2
|
||||||
|
private let sessionProgressIterations = 10
|
||||||
|
private let sessionCheckpointIterations = 100
|
||||||
|
private let sessionIterations = 1000
|
||||||
|
|
||||||
let server: URL
|
let server: URL
|
||||||
|
|
||||||
let configuration: Configuration
|
let configuration: Configuration
|
||||||
@ -276,13 +281,15 @@ final class ClassifierCreator {
|
|||||||
func trainModelAsync() async throws -> MLImageClassifier {
|
func trainModelAsync() async throws -> MLImageClassifier {
|
||||||
let params = MLImageClassifier.ModelParameters(
|
let params = MLImageClassifier.ModelParameters(
|
||||||
maxIterations: configuration.trainingIterations,
|
maxIterations: configuration.trainingIterations,
|
||||||
augmentation: [])
|
augmentation: [],
|
||||||
|
algorithm: .transferLearning(
|
||||||
|
featureExtractor: .scenePrint(revision: scenePrintRevision),
|
||||||
|
classifier: .logisticRegressor))
|
||||||
let sessionParameters = MLTrainingSessionParameters(
|
let sessionParameters = MLTrainingSessionParameters(
|
||||||
sessionDirectory: sessionDirectory,
|
sessionDirectory: sessionDirectory,
|
||||||
reportInterval: 10,
|
reportInterval: sessionProgressIterations,
|
||||||
checkpointInterval: 100,
|
checkpointInterval: sessionCheckpointIterations,
|
||||||
iterations: 1000)
|
iterations: sessionIterations)
|
||||||
|
|
||||||
var subscriptions = [AnyCancellable]()
|
var subscriptions = [AnyCancellable]()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user