Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Franck Dary
macaon
Commits
ed05ee4a
Commit
ed05ee4a
authored
Jun 05, 2020
by
Franck Dary
Browse files
Added Concat module
parent
f799c58f
Changes
18
Hide whitespace changes
Inline
Side-by-side
torch_modules/include/Concat.hpp
0 → 100644
View file @
ed05ee4a
#ifndef Concat__H
#define Concat__H
#include
<torch/torch.h>
#include
"MyModule.hpp"
class
ConcatImpl
:
public
MyModule
{
private
:
int
inputSize
;
public
:
ConcatImpl
(
int
inputSize
);
torch
::
Tensor
forward
(
torch
::
Tensor
input
);
int
getOutputSize
(
int
sequenceLength
);
};
TORCH_MODULE
(
Concat
);
#endif
torch_modules/include/ContextModule.hpp
View file @
ed05ee4a
...
...
@@ -6,6 +6,7 @@
#include
"MyModule.hpp"
#include
"GRU.hpp"
#include
"LSTM.hpp"
#include
"Concat.hpp"
class
ContextModuleImpl
:
public
Submodule
{
...
...
torch_modules/include/DepthLayerTreeEmbeddingModule.hpp
View file @
ed05ee4a
...
...
@@ -6,6 +6,7 @@
#include
"MyModule.hpp"
#include
"LSTM.hpp"
#include
"GRU.hpp"
#include
"Concat.hpp"
class
DepthLayerTreeEmbeddingModuleImpl
:
public
Submodule
{
...
...
torch_modules/include/FocusedColumnModule.hpp
View file @
ed05ee4a
...
...
@@ -6,6 +6,7 @@
#include
"MyModule.hpp"
#include
"LSTM.hpp"
#include
"GRU.hpp"
#include
"Concat.hpp"
class
FocusedColumnModuleImpl
:
public
Submodule
{
...
...
torch_modules/include/HistoryModule.hpp
View file @
ed05ee4a
...
...
@@ -6,6 +6,7 @@
#include
"MyModule.hpp"
#include
"LSTM.hpp"
#include
"GRU.hpp"
#include
"Concat.hpp"
class
HistoryModuleImpl
:
public
Submodule
{
...
...
torch_modules/include/NumericColumnModule.hpp
View file @
ed05ee4a
...
...
@@ -6,6 +6,7 @@
#include
"MyModule.hpp"
#include
"LSTM.hpp"
#include
"GRU.hpp"
#include
"Concat.hpp"
class
NumericColumnModuleImpl
:
public
Submodule
{
...
...
torch_modules/include/RawInputModule.hpp
View file @
ed05ee4a
...
...
@@ -6,6 +6,7 @@
#include
"MyModule.hpp"
#include
"LSTM.hpp"
#include
"GRU.hpp"
#include
"Concat.hpp"
class
RawInputModuleImpl
:
public
Submodule
{
...
...
torch_modules/include/SplitTransModule.hpp
View file @
ed05ee4a
...
...
@@ -6,6 +6,7 @@
#include
"MyModule.hpp"
#include
"LSTM.hpp"
#include
"GRU.hpp"
#include
"Concat.hpp"
class
SplitTransModuleImpl
:
public
Submodule
{
...
...
torch_modules/include/UppercaseRateModule.hpp
View file @
ed05ee4a
...
...
@@ -6,6 +6,7 @@
#include
"MyModule.hpp"
#include
"LSTM.hpp"
#include
"GRU.hpp"
#include
"Concat.hpp"
class
UppercaseRateModuleImpl
:
public
Submodule
{
...
...
torch_modules/src/Concat.cpp
0 → 100644
View file @
ed05ee4a
#include
"Concat.hpp"
ConcatImpl
::
ConcatImpl
(
int
inputSize
)
:
inputSize
(
inputSize
)
{
}
torch
::
Tensor
ConcatImpl
::
forward
(
torch
::
Tensor
input
)
{
return
input
.
view
({
input
.
size
(
0
),
-
1
});
}
int
ConcatImpl
::
getOutputSize
(
int
sequenceLength
)
{
return
sequenceLength
*
inputSize
;
}
torch_modules/src/ContextModule.cpp
View file @
ed05ee4a
...
...
@@ -32,6 +32,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
myModule
=
register_module
(
"myModule"
,
LSTM
(
columns
.
size
()
*
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"GRU"
)
myModule
=
register_module
(
"myModule"
,
GRU
(
columns
.
size
()
*
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"Concat"
)
myModule
=
register_module
(
"myModule"
,
Concat
(
inSize
));
else
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
...
...
torch_modules/src/DepthLayerTreeEmbeddingModule.cpp
View file @
ed05ee4a
...
...
@@ -38,6 +38,8 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string
depthModules
.
emplace_back
(
register_module
(
name
,
LSTM
(
columns
.
size
()
*
inSize
,
outSize
,
options
)));
else
if
(
subModuleType
==
"GRU"
)
depthModules
.
emplace_back
(
register_module
(
name
,
GRU
(
columns
.
size
()
*
inSize
,
outSize
,
options
)));
else
if
(
subModuleType
==
"Concat"
)
depthModules
.
emplace_back
(
register_module
(
name
,
Concat
(
inSize
)));
else
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
}
...
...
torch_modules/src/FocusedColumnModule.cpp
View file @
ed05ee4a
...
...
@@ -33,6 +33,8 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
myModule
=
register_module
(
"myModule"
,
LSTM
(
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"GRU"
)
myModule
=
register_module
(
"myModule"
,
GRU
(
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"Concat"
)
myModule
=
register_module
(
"myModule"
,
Concat
(
inSize
));
else
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
...
...
torch_modules/src/HistoryModule.cpp
View file @
ed05ee4a
...
...
@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
myModule
=
register_module
(
"myModule"
,
LSTM
(
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"GRU"
)
myModule
=
register_module
(
"myModule"
,
GRU
(
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"Concat"
)
myModule
=
register_module
(
"myModule"
,
Concat
(
inSize
));
else
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
...
...
torch_modules/src/NumericColumnModule.cpp
View file @
ed05ee4a
...
...
@@ -32,6 +32,8 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
myModule
=
register_module
(
"myModule"
,
LSTM
(
1
,
outSize
,
options
));
else
if
(
subModuleType
==
"GRU"
)
myModule
=
register_module
(
"myModule"
,
GRU
(
1
,
outSize
,
options
));
else
if
(
subModuleType
==
"Concat"
)
myModule
=
register_module
(
"myModule"
,
Concat
(
1
));
else
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
}
catch
(
std
::
exception
&
e
)
{
util
::
myThrow
(
fmt
::
format
(
"{} in '{}'"
,
e
.
what
(),
definition
));}
...
...
torch_modules/src/RawInputModule.cpp
View file @
ed05ee4a
...
...
@@ -27,6 +27,8 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
myModule
=
register_module
(
"myModule"
,
LSTM
(
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"GRU"
)
myModule
=
register_module
(
"myModule"
,
GRU
(
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"Concat"
)
myModule
=
register_module
(
"myModule"
,
Concat
(
inSize
));
else
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
...
...
torch_modules/src/SplitTransModule.cpp
View file @
ed05ee4a
...
...
@@ -26,6 +26,8 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con
myModule
=
register_module
(
"myModule"
,
LSTM
(
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"GRU"
)
myModule
=
register_module
(
"myModule"
,
GRU
(
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"Concat"
)
myModule
=
register_module
(
"myModule"
,
Concat
(
inSize
));
else
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
...
...
torch_modules/src/UppercaseRateModule.cpp
View file @
ed05ee4a
...
...
@@ -30,6 +30,8 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
myModule
=
register_module
(
"myModule"
,
LSTM
(
1
,
outSize
,
options
));
else
if
(
subModuleType
==
"GRU"
)
myModule
=
register_module
(
"myModule"
,
GRU
(
1
,
outSize
,
options
));
else
if
(
subModuleType
==
"Concat"
)
myModule
=
register_module
(
"myModule"
,
Concat
(
1
));
else
util
::
myThrow
(
fmt
::
format
(
"unknown sumodule type '{}'"
,
subModuleType
));
}
catch
(
std
::
exception
&
e
)
{
util
::
myThrow
(
fmt
::
format
(
"{} in '{}'"
,
e
.
what
(),
definition
));}
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment