import random import torch ################################################################################ class ReplayMemory(object): def __init__(self, capacity): self.capacity = capacity self.memory = [] self.position = 0 def push(self, transition): """Saves a transition.""" if len(self.memory) < self.capacity: self.memory.append(None) self.memory[self.position] = transition self.position = (self.position + 1) % self.capacity def sample(self, batch_size): return random.sample(self.memory, batch_size) def __len__(self): return len(self.memory) ################################################################################ ################################################################################ def selectAction(network, state, ts): sample = random.random() if sample > .2: with torch.no_grad(): return ts[max(torch.nn.functional.softmax(network(state), dim=1))].name else: return ts[random.randrange(len(ts))].name ################################################################################