Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Franck Dary
macaon
Commits
2d940e47
Commit
2d940e47
authored
Jun 29, 2020
by
Franck Dary
Browse files
Added cnn
parent
567e4969
Changes
4
Hide whitespace changes
Inline
Side-by-side
torch_modules/include/CNN.hpp
View file @
2d940e47
...
...
@@ -2,22 +2,21 @@
#define CNN__H
#include
<torch/torch.h>
#include
"MyModule.hpp"
class
CNNImpl
:
public
torch
::
nn
::
Module
class
CNNImpl
:
public
My
Module
{
private
:
std
::
vector
<
int
>
windowSizes
;
std
::
vector
<
torch
::
nn
::
Conv2d
>
CNNs
;
int
nbFilters
;
int
elemen
tSize
;
std
::
vector
<
int
>
windowSizes
{
2
,
3
}
;
int
outpu
tSize
;
public
:
CNNImpl
(
std
::
vector
<
int
>
window
Size
s
,
int
nbFilters
,
int
elementSize
);
CNNImpl
(
int
input
Size
,
int
outputSize
,
ModuleOptions
options
);
torch
::
Tensor
forward
(
torch
::
Tensor
input
);
std
::
size_t
getOutputSize
();
int
getOutputSize
(
int
sequenceLength
);
};
TORCH_MODULE
(
CNN
);
...
...
torch_modules/include/HistoryModule.hpp
View file @
2d940e47
...
...
@@ -6,6 +6,7 @@
#include
"MyModule.hpp"
#include
"LSTM.hpp"
#include
"GRU.hpp"
#include
"CNN.hpp"
#include
"Concat.hpp"
class
HistoryModuleImpl
:
public
Submodule
...
...
torch_modules/src/CNN.cpp
View file @
2d940e47
#include
"CNN.hpp"
#include
"fmt/core.h"
CNNImpl
::
CNNImpl
(
std
::
vector
<
int
>
windowSizes
,
int
nbFilters
,
int
elementSize
)
:
windowSizes
(
windowSizes
),
nbFilters
(
nbFilters
),
elementSize
(
elementSize
)
CNNImpl
::
CNNImpl
(
int
inputSize
,
int
outputSize
,
ModuleOptions
options
)
:
outputSize
(
outputSize
)
{
for
(
auto
&
windowSize
:
windowSizes
)
{
std
::
string
moduleName
=
fmt
::
format
(
"cnn_window_{}"
,
windowSize
);
CNNs
.
emplace_back
(
register_module
(
moduleName
,
torch
::
nn
::
Conv2d
(
torch
::
nn
::
Conv2dOptions
(
1
,
nbFilters
,
torch
::
ExpandingArray
<
2
>
({
windowSize
,
elementSize
})).
padding
({
windowSize
-
1
,
0
}))));
auto
kernel
=
torch
::
ExpandingArray
<
2
>
({
windowSize
,
inputSize
});
auto
opts
=
torch
::
nn
::
Conv2dOptions
(
1
,
outputSize
,
kernel
).
padding
({
windowSize
-
1
,
0
});
CNNs
.
emplace_back
(
register_module
(
moduleName
,
torch
::
nn
::
Conv2d
(
opts
)));
}
}
torch
::
Tensor
CNNImpl
::
forward
(
torch
::
Tensor
input
)
{
std
::
vector
<
torch
::
Tensor
>
windows
;
input
=
input
.
unsqueeze
(
1
);
for
(
unsigned
int
i
=
0
;
i
<
CNNs
.
size
();
i
++
)
{
auto
convOut
=
torch
::
relu
(
CNNs
[
i
](
input
).
squeeze
(
-
1
)
)
;
auto
pooled
=
torch
::
max_pool1d
(
convOut
,
convOut
.
size
(
2
));
auto
convOut
=
CNNs
[
i
](
input
).
squeeze
(
-
1
);
auto
pooled
=
torch
::
max_pool1d
(
convOut
,
convOut
.
size
(
-
1
));
windows
.
emplace_back
(
pooled
);
}
auto
cnnOut
=
torch
::
cat
(
windows
,
2
);
cnnOut
=
cnnOut
.
view
({
cnnOut
.
size
(
0
),
-
1
});
return
cnnOut
;
return
torch
::
cat
(
windows
,
-
1
).
view
({
input
.
size
(
0
),
-
1
});
}
std
::
size_
t
CNNImpl
::
getOutputSize
()
in
t
CNNImpl
::
getOutputSize
(
int
)
{
return
windowSizes
.
size
()
*
nbFilters
;
return
outputSize
*
windowSizes
.
size
();
}
torch_modules/src/HistoryModule.cpp
View file @
2d940e47
...
...
@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
myModule
=
register_module
(
"myModule"
,
LSTM
(
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"GRU"
)
myModule
=
register_module
(
"myModule"
,
GRU
(
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"CNN"
)
myModule
=
register_module
(
"myModule"
,
CNN
(
inSize
,
outSize
,
options
));
else
if
(
subModuleType
==
"Concat"
)
myModule
=
register_module
(
"myModule"
,
Concat
(
inSize
));
else
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment