Skip to content
Snippets Groups Projects
Select Git revision
  • 220ac3c5ccdfa1936ccfb5e88724115e383fa0f5
  • master default protected
2 results

round.py

Blame
  • round.py 1.59 KiB
    from abc import ABC
    from numbers import Number
    
    import torch
    from torch.autograd import Variable
    from distributions.distribution import Distribution
    from distributions.utils import broadcast_all
    
    
    class Round(Distribution, ABC):
        """Two-outcome distribution over {0, 1} whose samples are rounded ``probs``.

        ``sample`` performs no random draw: it returns ``probs`` rounded
        elementwise with ``torch.round``.  ``log_prob`` and ``entropy`` treat
        ``probs`` as the probability of outcome 1 and ``1 - probs`` as the
        probability of outcome 0.
        """

        has_enumerate_support = True

        def __init__(self, probs):
            """Store ``probs`` and register the batch shape with the base class.

            Args:
                probs: a Python number or a tensor-like object; it is passed
                    through ``broadcast_all`` before being stored on
                    ``self.probs``.
            """
            (self.probs,) = broadcast_all(probs)
            # A plain Python number yields an empty batch shape; anything else
            # takes the size of the value returned by ``broadcast_all``.
            batch_shape = (
                torch.Size() if isinstance(probs, Number) else self.probs.size()
            )
            super(Round, self).__init__(batch_shape)

        def sample(self, sample_shape=torch.Size()):
            """Return ``probs`` expanded to the sample shape and rounded.

            The target shape comes from ``self._extended_shape(sample_shape)``.
            """
            expanded_probs = self.probs.expand(self._extended_shape(sample_shape))
            return torch.round(expanded_probs)

        def log_prob(self, value):
            """Return the log-probability of each entry of ``value``.

            ``value`` is checked with ``self._validate_log_prob_arg`` and then
            cast to ``long`` to be used as an index, so its entries are expected
            to be 0 or 1.
            """
            self._validate_log_prob_arg(value)
            probs = self.probs.expand(value.size())
            # Last dimension holds the table [log P(0), log P(1)].
            log_table = torch.stack([1 - probs, probs], dim=-1).log()
            # Pick, per element, the table entry selected by ``value``.
            outcome_index = value.unsqueeze(-1).long()
            return log_table.gather(-1, outcome_index).squeeze(-1)

        def entropy(self):
            """Return ``-sum(p * log(p))`` over the two outcomes."""
            outcome_probs = torch.stack([self.probs, 1.0 - self.probs])
            weighted_log = torch.log(outcome_probs) * outcome_probs
            # Outcomes with probability exactly 0 contribute 0 to the sum
            # instead of the non-finite result of ``log(0) * 0``.
            weighted_log[outcome_probs == 0] = 0
            return -weighted_log.sum(0)

        def enumerate_support(self):
            """Return the support values 0 and 1 expanded over the batch shape.

            The result has shape ``(2,) + batch_shape`` and ``long`` dtype.  It
            is moved to the CUDA device of ``probs`` when ``probs`` lives on
            one, and wrapped in a ``Variable`` when ``probs`` is a ``Variable``.
            """
            batch_shape = self._batch_shape
            singleton_dims = (1,) * len(batch_shape)
            # Leading dimension enumerates the outcomes; the trailing singleton
            # dimensions are then expanded to match the batch shape.
            support = torch.arange(2).long().view((-1,) + singleton_dims)
            support = support.expand((-1,) + batch_shape)
            probs = self.probs
            if probs.is_cuda:
                support = support.cuda(probs.get_device())
            if isinstance(probs, Variable):
                support = Variable(support)
            return support