import random
import torch
import torch.nn.functional as F
from Util import getDevice

################################################################################
class ReplayMemory() :
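  """Fixed-capacity circular replay buffer backed by preallocated tensors on the
  current device. Terminal transitions are flagged in noNewStates rather than
  storing a next state."""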
  def __init__(self, capacity, stateSize) :
    self.capacity = capacity
    self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
    self.actions = torch.zeros(capacity, 1, dtype=torch.long, device=getDevice())
    self.rewards = torch.zeros(capacity, 1, device=getDevice())
    self.noNewStates = torch.zeros(capacity, dtype=torch.bool, device=getDevice())
    self.position = 0
    self.nbPushed = 0

  def push(self, state, action, newState, reward) :
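    """Store one transition at the current write position, overwriting the oldest
    entry once the buffer is full. newState may be None for terminal transitions,
    which are flagged in noNewStates."""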
    self.states[self.position] = state
    self.actions[self.position] = action
    if newState is not None :
      self.newStates[self.position] = newState
    self.noNewStates[self.position] = newState is None
    self.rewards[self.position] = reward 
    self.position = (self.position + 1) % self.capacity
    self.nbPushed += 1

  def sample(self, batchSize) :
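    """Return a batch of batchSize transitions taken as a contiguous window at a
    random offset (so samples within a batch are temporally correlated)."""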
    start = random.randint(0, len(self)-batchSize)
    end = start+batchSize
    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.noNewStates[start:end], self.rewards[start:end]

  def __len__(self):
    return min(self.nbPushed, self.capacity)
################################################################################

################################################################################
def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
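  """Choose a transition from ts: with probability probaRandom a uniformly random
  one, with probability probaOracle the appliable one with the lowest oracle score,
  and otherwise the one with the highest Q-value predicted by network for state."""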
  sample = random.random()
  if sample < probaRandom :
    return ts[random.randrange(len(ts))]
  elif sample < probaRandom+probaOracle :
    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)], key=lambda c : c[0])
    return candidates[0][1] if len(candidates) > 0 else None
  else :
    with torch.no_grad() :
      output = network(torch.stack([state]))
      predIndex = int(torch.argmax(output))
      return ts[predIndex]
################################################################################

################################################################################
def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
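  """Perform one DQN update: fit policy_net's Q(s, a) to the Bellman target
  r + gamma * max_a' Q_target(s', a') over a sampled batch, clamp gradients to
  [-1, 1] and step the optimizer. Returns the loss as a float, or 0.0 if the
  memory does not yet hold batchSize transitions."""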
  if len(memory) < batchSize :
    return 0.0

  states, actions, nextStates, noNextStates, rewards = memory.sample(batchSize)

  # Q(s, a) of the actions actually taken, from the online network
  predictedQ = policy_net(states).gather(1, actions)
  # max_a' Q_target(s', a') as a column vector, with terminal transitions zeroed out
  nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(1)
  nextQ[noNextStates] = 0.0

  # Bellman target: r + gamma * max_a' Q_target(s', a')
  expectedReward = gamma*nextQ + rewards

  loss = F.smooth_l1_loss(predictedQ, expectedReward)
  optimizer.zero_grad()
  loss.backward()
  # clamp gradients to [-1, 1] for stability
  torch.nn.utils.clip_grad_value_(policy_net.parameters(), 1.0)
  optimizer.step()

  return float(loss)
################################################################################

################################################################################
def rewarding(appliable, config, action, missingLinks, funcname):
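  """Dispatch to one of the reward functions below, e.g. funcname "A" calls rewardA."""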
  return globals()["reward"+funcname](appliable, config, action, missingLinks)
################################################################################

################################################################################
def rewardA(appliable, config, action, missingLinks):
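  """Reward scheme A: -oracleScore for regular appliable actions. For "BACK n"
  actions, the distance to the farthest erroneous entry among the last n-1 popped
  history items (an entry counts as erroneous if its stored reward, index 3, is
  negative; 0 if none is found) minus n. Non-appliable actions get -3.0."""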
  if appliable:
    if "BACK" not in action.name :
      reward = -1.0*action.getOracleScore(config, missingLinks)
    else :
      back = int(action.name.split()[-1])
      error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
      last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
      reward = last_error - back
  else:
    reward = -3.0
  return reward
################################################################################

################################################################################
def rewardB(appliable, config, action, missingLinks):
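  """Reward scheme B: identical to rewardA except that regular appliable actions
  get 1.0 - oracleScore instead of -oracleScore."""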
  if appliable:
    if "BACK" not in action.name :
      reward = 1.0 - action.getOracleScore(config, missingLinks)
    else :
      back = int(action.name.split()[-1])
      error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
      last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
      reward = last_error - back
  else:
    reward = -3.0
  return reward
################################################################################

################################################################################
def rewardC(appliable, config, action, missingLinks):
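  """Reward scheme C: -oracleScore for regular appliable actions. For "BACK n"
  actions, the negated sum of the rewards stored for the n most recently popped
  history items (index 3 of each entry). Non-appliable actions get -3.0."""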
  if appliable:
    if "BACK" not in action.name :
      reward = -action.getOracleScore(config, missingLinks)
    else :
      back = int(action.name.split()[-1])
      canceledRewards = [h[3] for h in config.historyPop[-back:]]
      reward = -sum(canceledRewards)
  else:
    reward = -3.0
  return reward
################################################################################

################################################################################
def rewardD(appliable, config, action, missingLinks):
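  """Reward scheme D: identical to rewardC except that BACK actions incur an
  additional -1 penalty."""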
  if appliable:
    if "BACK" not in action.name :
      reward = -action.getOracleScore(config, missingLinks)
    else :
      back = int(action.name.split()[-1])
      canceledRewards = [h[3] for h in config.historyPop[-back:]]
      reward = -sum(canceledRewards) - 1
  else:
    reward = -3.0
  return reward
################################################################################

################################################################################
def rewardE(appliable, config, action, missingLinks):
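  """Reward scheme E: -oracleScore for regular appliable actions, a flat -0.5 for
  BACK actions, and -3.0 for non-appliable actions."""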
  if appliable:
    if "BACK" not in action.name :
      reward = -action.getOracleScore(config, missingLinks)
    else :
      reward = -0.5
  else:
    reward = -3.0
  return reward
################################################################################