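// keras.h: loads a Keras model exported as a JSON description plus a weight
// storage file and evaluates it (Dense, Embedding, Merge, Dropout and
// InputLayer layers are supported).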
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "json.h"
#include "matrix.h"
class Node {
protected:
json::Value config;
std::vector<Node*> inbound;
std::string name;
public:
Node() { name = "NONE"; }
virtual ~Node() { }
Node(const json::Value& _config) : config(_config) { name = config["name"].to_string();}
void setup(std::map<std::string, Node*>& nodes) {
json::Value inbound_nodes = config["inbound_nodes"];
if(inbound_nodes.length() > 0) {
for(int j = 0; j < inbound_nodes[0].length(); j++) {
std::string node_name = inbound_nodes[0][j][0].to_string();
//std::cerr << node_name << "->" << name << "\n";
if(nodes.find(node_name) != nodes.end()) {
inbound.push_back(nodes[node_name]);
} else {
std::cerr << "ERROR: cannot find inbound layer \"" << node_name << "\" when setting up layer \"" << name << "\"\n";
exit(1);
}
}
}
}
virtual void set_input(const Matrix<float> & x) { }
virtual Matrix<float> get_output() {
assert(inbound.size() == 1);
assert(inbound[0] != NULL);
return forward(inbound[0]->get_output());
}
virtual Matrix<float> forward(const Matrix<float> & x) { return x; }
};
// Embedding layer: replaces every id found in the input by the matching row
// of the weight matrix W.
class Embedding : public Node {
protected:
    Matrix<float> W;  // row `id` is the vector substituted for that id

public:
    // Seeks `storage` to the offset given by config["weights"]["W"] and loads W from there.
    Embedding(const json::Value& _config, FILE* storage) : Node(_config) {
        fseek(storage, config["weights"]["W"].to_int(), SEEK_SET);
        W.load(storage);
    }

    // Builds a matrix with x.rows rows and x.cols * W.cols columns. For each
    // cell (i, j) of x, the value is truncated to an int id and W[id] is
    // assigned to output[i].slice(j * W.cols, W.cols) -- presumably the W.cols
    // columns starting at j * W.cols; confirm against Matrix::slice.
    // Ids outside [0, W.rows) are reported on stderr and replaced by id 0.
    virtual Matrix<float> forward(const Matrix<float> & x) {
        Matrix<float> result(x.rows, x.cols * W.cols);
        for(int row = 0; row < x.rows; row++) {
            for(int col = 0; col < x.cols; col++) {
                int id = static_cast<int>(x.at(row, col));
                const bool known = (id >= 0) && (id < W.rows);
                if(!known) {
                    std::cerr << "WARNING: unexpected embedding id " << id << " for row " << row << " in layer " << name << ", mapping to 0\n";
                    id = 0;
                }
                result[row].slice(col * W.cols, W.cols) = W[id];
            }
        }
        return result;
    }
};
class Dense : public Node {
protected:
Matrix<float> W, b;
Matrix<float> (*activation)(const Matrix<float>&);
public:
Dense(const json::Value& _config, FILE* storage) : Node(_config) {
fseek(storage, config["weights"]["W"].to_int(), SEEK_SET);
W.load(storage);
fseek(storage, config["weights"]["b"].to_int(), SEEK_SET);
b.load(storage);
std::string function = config["config"]["activation"].to_string();
if(function == "linear") activation = Matrix<float>::identity;
else if(function == "tanh") activation = Matrix<float>::tanh;
else if(function == "sigmoid") activation = Matrix<float>::sigmoid;
else if(function == "relu") activation = Matrix<float>::relu;
else if(function == "softmax") activation = Matrix<float>::softmax;
else {
std::cerr << "ERROR: unsupported activation function \"" << function << "\"\n";
exit(1);
}
}
virtual Matrix<float> forward(const Matrix<float> & x) {
return activation(x.dot(W) + b);
}
};
class Input : public Node {
Matrix<float> input;
public:
Input(const json::Value& _config) : Node(_config) { }
void set_input(const Matrix<float> &x) { input = x; }
Matrix<float> get_output() { return input; }
};
// Merge layer: places the outputs of all inbound nodes side by side,
// in inbound order.
class Merge : public Node {
public:
    Merge(const json::Value& _config) : Node(_config) { }

    // Evaluates every inbound node once, then builds a matrix whose column
    // count is the sum of the parts' column counts and whose row count is the
    // largest row count among them. Row r of each part is assigned to
    // output[r].slice(offset, part.cols), where offset is the total width of
    // the parts placed before it. Rows that a shorter part does not provide
    // are left as the Matrix constructor initialised them.
    virtual Matrix<float> get_output() {
        std::vector<Matrix<float> > parts(inbound.size());
        int total_cols = 0;
        int max_rows = 0;
        for(size_t k = 0; k < parts.size(); k++) {
            parts[k] = inbound[k]->get_output();
            total_cols += parts[k].cols;
            if(parts[k].rows > max_rows) max_rows = parts[k].rows;
        }

        Matrix<float> merged(max_rows, total_cols);
        int first_col = 0;
        for(size_t k = 0; k < parts.size(); k++) {
            Matrix<float>& part = parts[k];
            for(int r = 0; r < part.rows; r++) {
                merged[r].slice(first_col, part.cols) = part[r];
            }
            first_col += part.cols;
        }
        return merged;
    }
};
// Pass-through layer: its output is the output of its single inbound node.
// Model::load uses it for "Dropout" layers.
class Identity : public Node {
public:
    Identity(const json::Value& _config) : Node(_config) { }

    // Returns the inbound node's output unchanged.
    // Asserts that there is exactly one inbound node.
    virtual Matrix<float> get_output() {
        assert(inbound.size() == 1);
        Node* upstream = inbound.front();
        return upstream->get_output();
    }
};
class Model {
std::map<std::string, Node*> nodes;
std::vector<Node*> output_nodes;
std::vector<Node*> input_nodes;
public:
~Model() {
for(std::map<std::string, Node*>::iterator i = nodes.begin(); i != nodes.end(); i++) {
delete i->second;
}
}
int num_inputs() { return input_nodes.size(); }
int num_outputs() { return output_nodes.size(); }
std::vector<Matrix<float> > forward(std::vector<Matrix<float> > input) {
assert(input_nodes.size() == input.size());
for(size_t i = 0; i < input.size(); i++) {
input_nodes[i]->set_input(input[i]);
}
std::vector<Matrix<float> > output(output_nodes.size());
for(size_t i = 0; i < output_nodes.size(); i++) {
output[i] = output_nodes[i]->get_output();
}
return output;
}
Matrix<float> forward(const Matrix<float>& input) {
assert(input_nodes.size() == 1);
assert(output_nodes.size() == 1);
input_nodes[0]->set_input(input);
return output_nodes[0]->get_output();
}
static Model load(const char* json_filename, const char* storage_filename) {
Model model;
json::Value config = json::parse_file(json_filename);
FILE* storage = fopen(storage_filename, "r");
for(int i = 0; i < config["config"]["layers"].length(); i++) {
json::Value layer = config["config"]["layers"][i];
std::string name = layer["name"].to_string();
std::string class_name = layer["class_name"].to_string();
std::vector<std::string> inbound;
if(class_name == "Dense") model.nodes[name] = new Dense(layer, storage);
else if(class_name == "Embedding") model.nodes[name] = new Embedding(layer, storage);
else if(class_name == "Merge") model.nodes[name] = new Merge(layer);
else if(class_name == "Dropout") model.nodes[name] = new Identity(layer);
else if(class_name == "InputLayer") model.nodes[name] = new Input(layer);
else {
std::cerr << "ERROR: unsupported layer class \"" << class_name << "\"\n";
exit(1);
}
}
for(std::map<std::string, Node*>::iterator i = model.nodes.begin(); i != model.nodes.end(); i++) {
i->second->setup(model.nodes);
}
for(int i = 0; i < config["config"]["input_layers"].length(); i++) {
std::string name = config["config"]["input_layers"][i][0].to_string();
model.input_nodes.push_back(model.nodes[name]);
}
for(int i = 0; i < config["config"]["output_layers"].length(); i++) {
std::string name = config["config"]["output_layers"][i][0].to_string();
model.output_nodes.push_back(model.nodes[name]);
}
fclose(storage);
return model;
}
};