diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index eea032a47c010c3d329c45d9f5d6aa0e047ad20c..525e452e6834008f611a668b2c8653e510062b13 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -1,5 +1,6 @@ #include "MacaonDecode.hpp" #include <filesystem> +#include <execution> #include "util.hpp" #include "Decoder.hpp" #include "Submodule.hpp" @@ -122,11 +123,23 @@ int MacaonDecode::main() configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>()); } - for (unsigned int i = 0; i < configs.size(); i++) + machine.setDictsState(Dict::State::Closed); + + if (configs.size() > 1) { - decoder.decode(configs[i], beamSize, beamThreshold, debug, printAdvancement); - configs[i].print(stdout, i == 0); + NeuralNetworkImpl::device = torch::kCPU; + machine.to(NeuralNetworkImpl::device); + std::for_each(std::execution::par_unseq, configs.begin(), configs.end(), + [&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config) + { + decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); + }); } + else + decoder.decode(configs[0], beamSize, beamThreshold, debug, printAdvancement); + + for (unsigned int i = 0; i < configs.size(); i++) + configs[i].print(stdout, i == 0); } catch(std::exception & e) {util::error(e);}