diff --git a/code/train.py b/code/train.py index 72c91d8b004bdf0205b5c89856cae4e388c6a225..e70902b94f0394b5f97c570ed7211fadb2defee7 100644 --- a/code/train.py +++ b/code/train.py @@ -70,6 +70,12 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb extraction_strategy=parameters['extraction_strategy'] ) pretrained_estimator = ModelFactory.build(dataset.task, pretrained_model_parameters, library=library) + pretraned_trainer = Trainer(dataset) + pretraned_trainer.init(pretrained_estimator, subsets_used=parameters['subsets_used']) + pretrained_estimator.fit( + X=pretraned_trainer._X_forest, + y=pretraned_trainer._y_forest + ) else: pretrained_estimator = None pretrained_model_parameters = None