import sys
import Oracle
from Dicos import Dicos
from Config import Config
from Word import Word
from Mcd import Mcd
from Moves import Moves
from FeatModel import FeatModel
import torch

import numpy as np



def prepareWordBufferForDecode(buffer):
    """Reset the GOV and LABEL features of every word in the buffer.

    Each word found in ``buffer.array`` gets its GOV feature set to the
    string form of ``Word.invalidGov()`` and its LABEL feature set to
    ``Word.invalidLabel()``, so that decoding starts from initialization
    values. No other feature is touched and nothing is returned.
    """
    for token in buffer.array:
        # GOV is stored as a string, hence the explicit str() conversion.
        unset_gov = str(Word.invalidGov())
        token.setFeat('GOV', unset_gov)
        unset_label = Word.invalidLabel()
        token.setFeat('LABEL', unset_label)


# Debug switch: when True, the decoding loop prints each configuration and
# the move predicted for it.
verbose = False

# Exactly six positional arguments are required (argv[0] is the script name).
if len(sys.argv) != 7:
    print('usage:', sys.argv[0], 'mcf_file model_file dicos_file feat_model mcd_file words_limit')
    # sys.exit() instead of the site-provided exit(): the latter is meant for
    # the interactive shell and is missing when the site module is not loaded.
    sys.exit(1)

# mcf_file   -> input handed to Config
# model_file -> trained model handed to load_model
# dicos_file -> file handed to Dicos
# feat_model -> feature model description handed to FeatModel
# mcd_file   -> file handed to Mcd
mcf_file, model_file, dicos_file, feat_model, mcd_file = sys.argv[1:6]

# Upper bound on the number of words to decode; must be an integer. A bad
# value is reported on stderr instead of ending in a raw traceback.
try:
    wordsLimit = int(sys.argv[6])
except ValueError:
    sys.stderr.write('words_limit must be an integer, got: ' + sys.argv[6] + '\n')
    sys.exit(1)

    
# Build every resource needed for decoding, reporting progress on stderr.

sys.stderr.writelines(('reading mcd from file :', mcd_file, '\n'))
mcd = Mcd(mcd_file)

sys.stderr.write('loading dicos\n')
dicos = Dicos(fileName=dicos_file)

# The move inventory is derived from the dictionaries.
moves = Moves(dicos)

sys.stderr.writelines(('reading feature model from file :', feat_model, '\n'))
featModel = FeatModel(feat_model, dicos)

sys.stderr.writelines(('loading model :', model_file, '\n'))
# NOTE(review): load_model is neither defined nor imported in the visible
# part of this file, so this line raises NameError as written. The later
# model.predict(...) call looks like the Keras API -- confirm which import
# is intended and add it.
model = load_model(model_file)

# Network dimensions: input width comes from the feature model, output width
# is the number of moves.
inputSize = featModel.getInputSize()
outputSize = moves.getNb()

# Parser configuration reading from mcf_file, plus the decoding counters.
c = Config(mcf_file, mcd, dicos)
numSent, numWords = 0, 0
verbose = False

# Decode sentence after sentence until no sentence can be read any more or
# the word budget (wordsLimit) has been reached.
while c.getBuffer().readNextSentence() and numWords < wordsLimit:
    # Start every sentence with an empty stack and reset GOV/LABEL features.
    c.getStack().empty()
    prepareWordBufferForDecode(c.getBuffer())
    numWords += c.getBuffer().getLength()

    # Apply predicted moves until the configuration is final or no move can
    # be applied at all.
    while True:
        featVec = c.extractFeatVec(featModel)
        inputVector = featModel.buildInputVector(featVec, dicos)
        # One prediction at a time: the input is reshaped to a single row.
        outputVector = model.predict(inputVector.reshape((1, inputSize)), batch_size=1, verbose=0, steps=None)
        mvt_Code = outputVector.argmax()
        mvt = moves.mvtDecode(mvt_Code)

        if verbose:
            print("------------------------------------------")
            c.affiche()
            print('predicted move', mvt[0], mvt[1])
            print(mvt, featVec)

        res = c.applyMvt(mvt)
        if not res:
            sys.stderr.write("cannot apply predicted movement\n")
            # Forcing SHIFT only makes sense when SHIFT was not the move that
            # just failed.
            if mvt[0] != "SHIFT":
                sys.stderr.write("try to force SHIFT\n")
                res = c.shift()
            # Reached when the forced SHIFT failed, and also when the
            # predicted move itself was SHIFT. The latter case used to take no
            # recovery action, so the same configuration was presumably
            # predicted on again and the loop could spin forever.
            if not res:
                sys.stderr.write("try to force REDUCE\n")
                res = c.red()
            if not res:
                sys.stderr.write("abort sentence\n")
                break
        if c.isFinal():
            break

    # Print the words of the sentence, one per line.
    # NOTE(review): index 0 is skipped -- presumably an artificial root word;
    # confirm against the buffer implementation.
    for i in range(1, c.getBuffer().getLength()):
        c.getBuffer().getWord(i).affiche(mcd)
        print('')

    numSent += 1