Skip to content
Snippets Groups Projects
Commit e7740697 authored by Balthazar Casale's avatar Balthazar Casale
Browse files

Add approx_based.py

parent 7bb6ac11
No related branches found
No related tags found
No related merge requests found
"""
Contain approximation based method to solve the bipartite entanglement detection problem
"""
from BEData.models.labels import Label
import numpy as np
class MLModel :
"""
use a machine learning model (signature sklearn)
"""
def __init__(self, model, representation):
self.model = model
self.representation = representation
def predict(self, states, infos={}):
states, _ = self.representation(states)
return self.model.predict(states), {}
class DistToSep:
"""
Check if the state is at a certain distance from a separable approximation
"""
def __init__(self, dist_threshold, sep_key=None, sep_mthd=None):
self.dist_threshold = dist_threshold
self.sep_key = sep_key
self.sep_mthd = sep_mthd
def predict(self, states, infos={}):
aprx = None
inf_aprx = {}
return_aprx = False
if self.sep_key is not None :
aprx = infos[self.sep_key]
elif self.sep_mthd is not None :
return_aprx = True
aprx, inf_aprx = self.sep_mthd(states, infos)
y = np.full(len(states), Label.ENT)
y[np.linalg.norm(aprx - states, axis=(1,2)) < self.dist_threshold] = Label.SEP
if return_aprx :
return y, {'aprx' : aprx, **inf_aprx}
else :
return y, {}
class WitQuality:
"""
Check if the approximate witness for the state is of good enough quality.
Can operate either with entanglement witnesses (data_type = 'wit') or separable approximation (data_type = 'approx')
"""
def __init__(self, min_score, sep_test_set, data_type = 'wit', data_key=None, data_mthd=None, return_scores = False):
self.min_score = min_score
self.test_set = sep_test_set
self.data_type = data_type
self.data_key = data_key
self.data_mthd = data_mthd
self.return_scores = return_scores
def predict(self, states, infos={}):
dim = np.product(states.dims)
data = None
inf_data = {}
return_data = False
if self.data_key is not None :
data = infos[self.data_key]
elif self.data_mthd is not None :
data, inf_data = self.data_mthd(states, infos)
if self.data_type == 'approx' :
C = np.trace(data @ (data - states), axis1=1, axis2=2).real
wits = (data - states) - C[:,None,None] * np.full(states.shape, np.eye(dim))
wits /= np.trace(wits, axis1=1, axis2=2)[:,None,None]
else :
wits = data
resp = np.full(len(states), True)
# tr(W_rho rho) < 0
resp = np.logical_and(resp, np.trace(wits @ states, axis1=1, axis2=2).real < 0)
# tr(W_rho sigma) >= 0 for vast majority of separables sigma
scores = np.zeros(len(states))
for i in range(len(wits)) :
scores[i] = np.average(np.trace(np.matmul(wits[i], self.test_set), axis1=1, axis2=2).real >= 0)
resp = np.logical_and(resp, scores > self.min_score)
if self.return_scores :
inf_wits = {'wit_score' : scores, **inf_data}
if return_data :
inf_wits = {'wit' : wits, **inf_data}
y = np.full(len(states), Label.SEP)
y[resp] = Label.ENT
return y, inf_data
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment