diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index c087469af57b04e7e7854c2569046988dd6ccd3b..c2099fd6a8a21e897fc9490e4f7732d5ca613ec3 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -9,6 +9,8 @@ class Trainer { private : + static constexpr std::size_t safetyNbExamplesMax = 10*1000*1000; + struct Examples { std::vector<torch::Tensor> contexts; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index c19ca2c25270c145fe2222f2d57ea6f0841854e4..5306a2cf8cffb1beebd30a1198d6ae386058937d 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -122,6 +122,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition); totalNbExamples += context.size(); + if (totalNbExamples >= (int)safetyNbExamplesMax) + util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); examplesPerState[config.getState()].addContext(context); examplesPerState[config.getState()].addClass(goldIndex);