Skip to content
Snippets Groups Projects
Commit a87dea15 authored by Balthazar Casale's avatar Balthazar Casale
Browse files

Add utils.py

parent 12b8d9e2
No related branches found
No related tags found
No related merge requests found
"""
Contain utility functions and sampler function used in the library
"""
from BEData.types import DMStack
# FUNCTIONS
def kron(m1, m2):
"""
The kronecker product of two stacks of matrices, element wise
"""
return np.array([np.kron(m1[i], m2[i]) for i in range(len(m1))])
def dagger(m):
"""
The conjugate transpose of each matrices
"""
return np.transpose(m, axes=(0,2,1)).conjugate()
def balanced(n_states, n_params):
"""
Compute the number of sample returned for each parameters.
"""
n_per_val = np.full(n_params, int(n_states/n_params))
remaining = n_states - np.sum(n_per_val)
while remaining > 0 :
remaining -= 1
n_per_val[np.random.randint(0, n_params)] += 1
return n_per_val
# SAMPLERS
class RandomGinibre :
"""
Random complex-valued matrix where each entry is a random variable
sampled from the normal distribution
"""
@staticmethod
def matrices(n_mats, dims):
return np.random.randn(n_mats, *dims) + 1j * np.random.randn(n_mats, *dims)
class RandomUnitary :
"""
Random unitary matrices sampled from the Haar measure
"""
def __init__(self, product=False):
self.product = product
def matrices(self, n_mats, dims):
if self.product :
s1 = RandomUnitary().matrices(n_mats, [dims[0]])
s2 = RandomUnitary().matrices(n_mats, [dims[1]])
return kron(s1,s2)
else :
dim = np.product(dims)
return np.linalg.qr(RandomGinibre.matrices(n_mats, [dim, dim]), mode='complete')[0]
class FromSet :
"""
Random subset from an already existing set.
"""
def __init__(self, states, infos):
self.datas = states
self.infos = infos
def states(self, n_states, dims=None):
dims = self.datas.dims
if n_states >= len(self.datas) :
return DMStack(self.datas.copy(), dims), self.infos.copy()
idx = np.arange(len(self.datas))
np.random.shuffle(idx)
idx = idx[:n_states]
states = self.datas.copy()[idx]
infos = self.infos.copy()
for key in infos.keys() :
infos[key] = infos[key][idx]
return DMStack(states, dims), infos
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment