Skip to content
Snippets Groups Projects
Commit 6e1d31e3 authored by Franck Dary's avatar Franck Dary
Browse files

Added class Dict and worked on DataLoader

parent 0733a6af
Branches
No related tags found
No related merge requests found
#ifndef DICT__H
#define DICT__H

#include <cstdio>  // std::FILE — used in save/readEntry/printEntry signatures
#include <string>
#include <unordered_map>

/// \brief Mapping from string elements to contiguous integer indexes.
///
/// A Dict is either Open (unseen elements are inserted on lookup) or
/// Closed (unseen elements resolve to the index of the unknown-value
/// sentinel). It can be serialized to / deserialized from a FILE in
/// either an Ascii or a Binary encoding.
class Dict
{
  public :

  /// Open : getIndexOrInsert inserts unseen elements.
  /// Closed : unseen elements map to the unknown-value sentinel.
  enum State {Open, Closed};
  /// On-disk entry format used by save / readEntry / printEntry.
  enum Encoding {Binary, Ascii};

  private :

  /// Sentinel element returned for out-of-vocabulary lookups on a Closed dict.
  static constexpr char const * unknownValueStr = "__unknownValue__";
  /// Sentinel element representing a deliberately empty / null value.
  static constexpr char const * nullValueStr = "__nullValue__";
  /// Maximum length in bytes (excluding the terminator) of a stored element.
  static constexpr std::size_t maxEntrySize = 5000;

  private :

  /// The element -> index mapping; indexes are assigned in insertion order.
  std::unordered_map<std::string, int> elementsToIndexes;
  State state;

  public :

  Dict(State state);
  Dict(const char * filename, State state);

  private :

  void readFromFile(const char * filename);

  public :

  void insert(const std::string & element);
  int getIndexOrInsert(const std::string & element);
  void setState(State state);
  void save(std::FILE * destination, Encoding encoding);
  bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding);
  void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding);
};

#endif
#include "Dict.hpp"
#include "util.hpp"
// Build an empty dictionary in the given state.
// The two reserved sentinel entries are inserted first, so they always
// receive the fixed indexes 0 (unknown value) and 1 (null value) —
// the insertion order here matters.
Dict::Dict(State state)
{
setState(state);
insert(unknownValueStr);
insert(nullValueStr);
}
// Build a dictionary whose content is loaded from `filename`
// (a file previously produced by save()), then set its state.
// NOTE(review): the sentinel entries are not re-inserted here — presumably
// they were serialized with the rest of the dict; verify against save().
Dict::Dict(const char * filename, State state)
{
readFromFile(filename);
setState(state);
}
// Load the dictionary content from `filename`.
// Expected layout (as written by save()): an "Encoding : X" line, a
// "Nb entries : N" line, then N entries in the declared encoding.
// Throws via util::myThrow on any format error.
void Dict::readFromFile(const char * filename)
{
  std::FILE * file = std::fopen(filename, "r");
  if (!file)
    util::myThrow(fmt::format("could not open file '{}'", filename));

  char buffer[1048];
  if (std::fscanf(file, "Encoding : %1047s\n", buffer) != 1)
  {
    // Close before throwing : fmt placeholders are '{}', not printf '%s'
    // (with '%s' the filename was silently dropped from the message).
    std::fclose(file);
    util::myThrow(fmt::format("file '{}' bad format", filename));
  }

  Encoding encoding{Encoding::Ascii};
  if (std::string(buffer) == "Ascii")
    encoding = Encoding::Ascii;
  else if (std::string(buffer) == "Binary")
    encoding = Encoding::Binary;
  else
  {
    std::fclose(file);
    util::myThrow(fmt::format("file '{}' bad format", filename));
  }

  int nbEntries;
  // Reject negative counts : reserve() takes an unsigned size and a
  // corrupted negative value would request a huge allocation.
  if (std::fscanf(file, "Nb entries : %d\n", &nbEntries) != 1 || nbEntries < 0)
  {
    std::fclose(file);
    util::myThrow(fmt::format("file '{}' bad format", filename));
  }
  elementsToIndexes.reserve(nbEntries);

  int entryIndex;
  char entryString[maxEntrySize+1];
  for (int i = 0; i < nbEntries; i++)
  {
    if (!readEntry(file, &entryIndex, entryString, encoding))
    {
      std::fclose(file);
      util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
    }
    elementsToIndexes[entryString] = entryIndex;
  }

  std::fclose(file);
}
// Register `element` in the dictionary, assigning it the next free index.
// Throws if the element exceeds maxEntrySize; a duplicate insert is a
// no-op (emplace leaves the existing index untouched).
void Dict::insert(const std::string & element)
{
  const auto elementSize = element.size();
  if (elementSize > maxEntrySize)
    util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", elementSize, maxEntrySize));
  const int newIndex = static_cast<int>(elementsToIndexes.size());
  elementsToIndexes.emplace(element, newIndex);
}
// Return the index of `element`.
// Open dict : unseen elements are inserted first, so the lookup always
// succeeds. Closed dict : unseen elements resolve to the index of the
// unknown-value sentinel.
int Dict::getIndexOrInsert(const std::string & element)
{
  if (state == State::Open)
    insert(element); // no-op if already present

  // Iterator by value : binding it to `const auto &` (as before) only
  // obscured that find() returns a temporary.
  const auto found = elementsToIndexes.find(element);
  if (found != elementsToIndexes.end())
    return found->second;

  // Use at() rather than operator[] : on a dict that somehow lacks the
  // sentinel, operator[] would silently insert it with a bogus index
  // (value-initialized, colliding with index 0) instead of failing loudly.
  return elementsToIndexes.at(unknownValueStr);
}
// Switch between Open (lookups may insert) and Closed (read-only lookups).
void Dict::setState(State state)
{
this->state = state;
}
// Serialize the dictionary to `destination` in the given encoding,
// using exactly the layout readFromFile expects.
void Dict::save(std::FILE * destination, Encoding encoding)
{
  std::fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
  // %zu is the portable conversion for std::size_t — %lu is undefined
  // behaviour on platforms where size_t is not unsigned long (e.g. Win64).
  std::fprintf(destination, "Nb entries : %zu\n", elementsToIndexes.size());
  for (const auto & it : elementsToIndexes)
    printEntry(destination, it.second, it.first, encoding);
}
// Read one (index, entry) pair from `file`.
// `entry` must point to at least maxEntrySize+1 bytes. Returns false on
// EOF or malformed input.
bool Dict::readEntry(std::FILE * file, int * index, char * entry, Encoding encoding)
{
  if (encoding == Encoding::Ascii)
  {
    // "%d\t%5000[^\n]\n" : index, tab, then up to maxEntrySize chars
    // (fscanf appends the terminator, hence the +1 buffer requirement).
    static const std::string readFormat = "%d\t%"+std::to_string(maxEntrySize)+"[^\n]\n";
    return std::fscanf(file, readFormat.c_str(), index, entry) == 2;
  }
  else
  {
    if (std::fread(index, sizeof *index, 1, file) != 1)
      return false;
    // Read byte-by-byte until the NUL terminator written by printEntry.
    // Scan maxEntrySize+1 bytes : a maximum-length entry (which insert()
    // and printEntry() both allow) has its terminator at position
    // maxEntrySize — the previous bound of maxEntrySize rejected it.
    for (unsigned int i = 0; i < maxEntrySize+1; i++)
    {
      if (std::fread(entry+i, 1, 1, file) != 1)
        return false;
      if (!entry[i])
        return true;
    }
    return false; // unterminated entry : corrupt file
  }
}
// Write one (index, entry) pair to `file` in the requested encoding.
// Ascii : "<index>\t<entry>\n". Binary : the raw int bytes followed by
// the NUL-terminated string bytes.
void Dict::printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding)
{
  if (encoding == Encoding::Ascii)
  {
    std::fprintf(file, "%d\t%s\n", index, entry.c_str());
  }
  else
  {
    std::fwrite(&index, sizeof index, 1, file);
    std::fwrite(entry.c_str(), 1, entry.size()+1, file);
  }
}
...@@ -36,7 +36,12 @@ int main(int argc, char * argv[]) ...@@ -36,7 +36,12 @@ int main(int argc, char * argv[])
std::vector<torch::Tensor> predictionsBatch; std::vector<torch::Tensor> predictionsBatch;
std::vector<torch::Tensor> referencesBatch; std::vector<torch::Tensor> referencesBatch;
std::vector<SubConfig> configs; std::vector<std::unique_ptr<Config>> configs;
std::vector<std::size_t> classes;
fmt::print("Generating dataset...");
Dict dict(Dict::State::Open);
while (true) while (true)
{ {
...@@ -54,10 +59,8 @@ int main(int argc, char * argv[]) ...@@ -54,10 +59,8 @@ int main(int argc, char * argv[])
// auto loss = torch::nll_loss(prediction, gold); // auto loss = torch::nll_loss(prediction, gold);
// loss.backward(); // loss.backward();
// optimizer.step(); // optimizer.step();
configs.emplace_back(config); configs.emplace_back(std::unique_ptr<Config>(new SubConfig(config)));
classes.emplace_back(goldIndex);
if (config.getWordIndex()%1 == 0)
fmt::print("{:.5f}%\n", config.getWordIndex()*100.0/goldConfig.getNbLines());
// if (config.getWordIndex() >= 500) // if (config.getWordIndex() >= 500)
// exit(1); // exit(1);
...@@ -77,6 +80,18 @@ int main(int argc, char * argv[]) ...@@ -77,6 +80,18 @@ int main(int argc, char * argv[])
config.update(); config.update();
} }
auto dataset = ConfigDataset(configs, classes, machine.getTransitionSet().size(), dict).map(torch::data::transforms::Stack<>());
fmt::print("Done!\n");
auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset), 50);
for (auto & batch : *dataLoader)
{
auto data = batch.data;
auto labels = batch.target.squeeze();
}
return 0; return 0;
} }
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <boost/flyweight.hpp> #include <boost/flyweight.hpp>
#include <boost/circular_buffer.hpp> #include <boost/circular_buffer.hpp>
#include "util.hpp" #include "util.hpp"
#include "Dict.hpp"
class Config class Config
{ {
...@@ -15,6 +16,8 @@ class Config ...@@ -15,6 +16,8 @@ class Config
static constexpr const char * EOSColName = "EOS"; static constexpr const char * EOSColName = "EOS";
static constexpr const char * EOSSymbol1 = "1"; static constexpr const char * EOSSymbol1 = "1";
static constexpr const char * EOSSymbol0 = "0"; static constexpr const char * EOSSymbol0 = "0";
static constexpr const char * headColName = "HEAD";
static constexpr const char * idColName = "ID";
static constexpr int nbHypothesesMax = 1; static constexpr int nbHypothesesMax = 1;
public : public :
...@@ -96,6 +99,7 @@ class Config ...@@ -96,6 +99,7 @@ class Config
String getState() const; String getState() const;
void setState(const std::string state); void setState(const std::string state);
bool stateIsDone() const; bool stateIsDone() const;
std::vector<int> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
}; };
......
...@@ -366,3 +366,31 @@ bool Config::stateIsDone() const ...@@ -366,3 +366,31 @@ bool Config::stateIsDone() const
return !has(0, wordIndex+1, 0); return !has(0, wordIndex+1, 0);
} }
// Extract a window of dictionary indexes around the current word.
// Walks up to `leftBorder` tokens backwards and `rightBorder` tokens
// forwards from wordIndex (skipping non-token lines), then maps each
// token's FORM to its index via `dict`.
// NOTE(review): currently stubbed — it always returns {0} (see the TODO
// below); the computed `context` is dead code until the TODO is resolved.
std::vector<int> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
{
std::vector<int> context;
// Move startIndex back by up to leftBorder tokens; the inner do/while
// skips non-token lines while a previous line exists.
int startIndex = wordIndex;
for (int i = 0; i < leftBorder and has(0,startIndex-1,0); i++)
do
--startIndex;
while (!isToken(startIndex) and has(0,startIndex-1,0));
// Symmetrically, move endIndex forward by up to rightBorder tokens.
int endIndex = wordIndex;
for (int i = 0; i < rightBorder and has(0,endIndex+1,0); i++)
do
++endIndex;
while (!isToken(endIndex) and has(0,endIndex+1,0));
// Collect the dict index of every token's FORM in [startIndex, endIndex].
for (int i = startIndex; i <= endIndex; ++i)
if (isToken(i))
context.emplace_back(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", i)));
//TODO handle the cases where the size differs...
return {0};
return context;
}
...@@ -8,12 +8,14 @@ class ConfigDataset : public torch::data::Dataset<ConfigDataset> ...@@ -8,12 +8,14 @@ class ConfigDataset : public torch::data::Dataset<ConfigDataset>
{ {
private : private :
std::vector<Config> const & configs; std::vector<std::unique_ptr<Config>> const & configs;
std::vector<std::size_t> const & classes; std::vector<std::size_t> const & classes;
std::size_t nbClasses;
Dict & dict;
public : public :
ConfigDataset(std::vector<Config> const & configs, std::vector<std::size_t> const & classes); explicit ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs, std::vector<std::size_t> const & classes, std::size_t nbClasses, Dict & dict);
torch::optional<size_t> size() const override; torch::optional<size_t> size() const override;
torch::data::Example<> get(size_t index) override; torch::data::Example<> get(size_t index) override;
}; };
......
#include "ConfigDataset.hpp" #include "ConfigDataset.hpp"
ConfigDataset::ConfigDataset(std::vector<Config> const & configs, std::vector<std::size_t> const & classes) : configs(configs), classes(classes) ConfigDataset::ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs, std::vector<std::size_t> const & classes, std::size_t nbClasses, Dict & dict) : configs(configs), classes(classes), nbClasses(nbClasses), dict(dict)
{ {
} }
torch::optional<size_t> ConfigDataset::size() const torch::optional<size_t> ConfigDataset::size() const
{ {
return configs.size();
} }
torch::data::Example<> ConfigDataset::get(size_t index) torch::data::Example<> ConfigDataset::get(size_t index)
{ {
auto context = configs[index]->extractContext(5,5,dict);
auto tensorClass = torch::zeros(nbClasses);
tensorClass[classes[index]] = 1;
return {torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone(), tensorClass};
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment