Franck Dary / macaon / Commits

Commit b13669bd, authored Aug 04, 2020 by Franck Dary

Added program arguments: scaleGrad and maxNorm

Parent: 397e390f
Changes: 23 files

This commit replaces the raw torch::nn::Embedding members of the submodules with a new WordEmbeddings wrapper module, so that the two new program arguments (scaleGrad and maxNorm) can configure every embedding table in one place.
torch_modules/include/ContextModule.hpp

@@ -9,12 +9,13 @@
 #include "LSTM.hpp"
 #include "Concat.hpp"
 #include "Transformer.hpp"
+#include "WordEmbeddings.hpp"

 class ContextModuleImpl : public Submodule
 {
   private :

-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;

torch_modules/include/ContextualModule.hpp

@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "LSTM.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"

 class ContextualModuleImpl : public Submodule
 {
   private :

-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;

torch_modules/include/DepthLayerTreeEmbeddingModule.hpp

@@ -7,6 +7,7 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"

 class DepthLayerTreeEmbeddingModuleImpl : public Submodule
 {
@@ -16,7 +17,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   std::vector<std::string> columns;
   std::vector<int> focusedBuffer;
   std::vector<int> focusedStack;
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::vector<std::shared_ptr<MyModule>> depthModules;
   int inSize;

torch_modules/include/DistanceModule.hpp

@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"

 class DistanceModuleImpl : public Submodule
 {
   private :

-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> fromBuffer, fromStack;
   std::vector<int> toBuffer, toStack;

torch_modules/include/FocusedColumnModule.hpp

@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"

 class FocusedColumnModuleImpl : public Submodule
 {
   private :

-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> focusedBuffer, focusedStack;
   std::string column;

torch_modules/include/HistoryModule.hpp

@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "CNN.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"

 class HistoryModuleImpl : public Submodule
 {
   private :

-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbElements;
   int inSize;

torch_modules/include/RawInputModule.hpp

@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"

 class RawInputModuleImpl : public Submodule
 {
   private :

-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int leftWindow, rightWindow;
   int inSize;

torch_modules/include/SplitTransModule.hpp

@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"

 class SplitTransModuleImpl : public Submodule
 {
   private :

-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbTrans;
   int inSize;

torch_modules/include/StateNameModule.hpp

@@ -6,12 +6,13 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "WordEmbeddings.hpp"

 class StateNameModuleImpl : public Submodule
 {
   private :

-  torch::nn::Embedding embeddings{nullptr};
+  WordEmbeddings embeddings{nullptr};
   int outSize;

   public :

torch_modules/include/Submodule.hpp

@@ -16,7 +16,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
   public :

   void setFirstInputIndex(std::size_t firstInputIndex);
-  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix);
+  void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;

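(Note on the signature change above: torch::nn::Embedding is a LibTorch module holder, essentially a shared_ptr to the underlying module, so passing it by value still lets loadPretrainedW2vEmbeddings fill the same weight table. The reference presumably had to go because the call sites below now pass the temporary returned by wordEmbeddings->get(), which cannot bind to a non-const lvalue reference.)
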
torch_modules/include/WordEmbeddings.hpp (new file, mode 100644)

+#ifndef WORDEMBEDDINGS__H
+#define WORDEMBEDDINGS__H
+
+#include "torch/torch.h"
+
+class WordEmbeddingsImpl : public torch::nn::Module
+{
+  private :
+
+  static bool scaleGradByFreq;
+  static float maxNorm;
+
+  private :
+
+  torch::nn::Embedding embeddings{nullptr};
+
+  public :
+
+  static void setScaleGradByFreq(bool scaleGradByFreq);
+  static void setMaxNorm(float maxNorm);
+
+  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
+  torch::nn::Embedding get();
+  torch::Tensor forward(torch::Tensor input);
+};
+TORCH_MODULE(WordEmbeddings);
+
+#endif
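
The definitions for WordEmbeddingsImpl are not among the files shown on this page of the diff. As a rough sketch of what the wrapper plausibly does with the two static settings (the default values and the "maxNorm <= 0 means unset" guard are assumptions, not taken from the commit):

  // Hypothetical WordEmbeddings.cpp sketch; only the interface above comes from the commit.
  #include "WordEmbeddings.hpp"

  bool WordEmbeddingsImpl::scaleGradByFreq = false; // assumed default
  float WordEmbeddingsImpl::maxNorm = 0.0;          // assumed: non-positive means "no renormalization"

  void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
  {
    WordEmbeddingsImpl::scaleGradByFreq = scaleGradByFreq;
  }

  void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
  {
    WordEmbeddingsImpl::maxNorm = maxNorm;
  }

  WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
  {
    // scale_grad_by_freq and max_norm are standard torch::nn::EmbeddingOptions fields:
    // the former scales gradients by the inverse word frequency in the mini-batch, the
    // latter renormalizes any embedding whose norm exceeds max_norm. These are what the
    // two new program arguments expose.
    auto options = torch::nn::EmbeddingOptions(vocab, dim).scale_grad_by_freq(scaleGradByFreq);
    if (maxNorm > 0.0)
      options.max_norm(maxNorm);
    embeddings = register_module("embeddings", torch::nn::Embedding(options));
  }

  torch::nn::Embedding WordEmbeddingsImpl::get()
  {
    return embeddings;
  }

  torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
  {
    return embeddings(input);
  }

With a wrapper like this, the scaleGrad and maxNorm program arguments only need to call the static setters once, before any registerEmbeddings() runs, e.g. WordEmbeddingsImpl::setScaleGradByFreq(true); WordEmbeddingsImpl::setMaxNorm(1.0); (values here are illustrative). Every embedding table in the model then picks the settings up at construction time.
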
torch_modules/src/ContextModule.cpp

@@ -161,12 +161,12 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }

torch_modules/src/ContextualModule.cpp

@@ -210,13 +210,13 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }

torch_modules/src/DepthLayerTreeEmbeddingModule.cpp

@@ -126,6 +126,6 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }

torch_modules/src/DistanceModule.cpp

@@ -111,6 +111,6 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void DistanceModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }

torch_modules/src/FocusedColumnModule.cpp

@@ -156,12 +156,12 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }

torch_modules/src/HistoryModule.cpp

@@ -69,6 +69,6 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void HistoryModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }

torch_modules/src/RawInputModule.cpp

@@ -78,6 +78,6 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void RawInputModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }

torch_modules/src/SplitTransModule.cpp

@@ -65,6 +65,6 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void SplitTransModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }

torch_modules/src/StateNameModule.cpp

@@ -38,6 +38,6 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void StateNameModuleImpl::registerEmbeddings()
 {
-  embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize));
+  embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
 }

(20 of the 23 changed files are shown above; the diff continues on a second page.)