import torch
import torch.nn as nn
import torch.nn.functional as F
import Features
################################################################################
class BaseNet(nn.Module):
  def __init__(self, dicts, outputSize) :
    super().__init__()
    # Dummy parameter used only to report which device the module lives on
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
    # Feature template: context positions in the buffer (b.*) and on the stack (s.*)
    self.featureFunction = "b.-2 b.-1 b.0 b.1 b.2 s.0 s.1 s.2 s.0.0 s.0.-1 s.0.1 s.1.0 s.1.-1 s.1.1 s.2.0 s.2.-1 s.2.1"
    # Input columns used as features (here only the universal POS tag)
    self.columns = ["UPOS"]
    self.inputSize = len(self.columns)*len(self.featureFunction.split())
    self.embSize = 64 # embedding size per feature; value assumed, not shown in this excerpt
    # One embedding table per dictionary, registered as a named submodule
    for name in dicts.dicts :
      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
    self.fc1 = nn.Linear(self.inputSize * self.embSize, 1600)
    self.fc2 = nn.Linear(1600, outputSize)
    self.dropout = nn.Dropout(0.3)
    self.apply(self.initWeights)
  def forward(self, x) :
    x = self.dropout(getattr(self, "emb_"+"UPOS")(x).view(x.size(0), -1))
    x = F.relu(self.dropout(self.fc1(x)))
    return self.fc2(x)
  def currentDevice(self) :
    return self.dummyParam.device

  def initWeights(self, m) :
    # Xavier initialization for linear layers, small positive bias
    if type(m) == nn.Linear :
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)
  def extractFeatures(self, dicts, config) :
    return Features.extractColsFeatures(dicts, config, self.featureFunction, self.columns)
################################################################################
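################################################################################
# Minimal usage sketch, not part of the original file. Assumptions, since they
# are not shown in this excerpt: `dicts.dicts` maps a column name to a
# token->index dictionary, and each of the 17 feature positions is encoded as
# one integer index. The project-local `Features` module is not exercised here.
if __name__ == "__main__" :
  class DummyDicts :
    def __init__(self) :
      self.dicts = {"UPOS" : {tag : i for i, tag in enumerate(["<unk>", "DET", "NOUN", "VERB"])}}

  net = BaseNet(DummyDicts(), outputSize=10)
  batch = torch.randint(0, 4, (2, 17)) # 2 examples, 17 UPOS feature indices each
  scores = net(batch)                  # shape (2, 10): one score per output class
  print(scores.shape, net.currentDevice())
################################################################################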