Skip to content
Snippets Groups Projects
Rl.py 1.14 KiB
Newer Older
Maxime Petit's avatar
Maxime Petit committed
import random
import torch

################################################################################
class ReplayMemory(object):

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, transition):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

################################################################################

################################################################################

def selectAction(network, state, ts):
    sample = random.random()
    if sample > .2:
        with torch.no_grad():
            return ts[max(torch.nn.functional.softmax(network(state), dim=1))].name

    else:
        return ts[random.randrange(len(ts))].name


################################################################################