diff --git a/README.md b/README.md index 54de1ab3153533a3155aa751100a4d9ed74386a2..d5fb6eca6553d0141e7f1d29c2fde277e3f3d9c3 100644 --- a/README.md +++ b/README.md @@ -24,12 +24,13 @@ We give a typical use case in the following snipped of code : ```python from types import save_dmstack, load_dmstack from pipeline import * +from samplers.mixed import RandomInduced from models.criteria import PPT from models.approx_based import DistToSep from transformers.sep_approximation import FrankWolfe states, infos = Pipeline([ - ('sample', InducedMeasure(k_params=[25]).states), # induced measure of parameter 25 + ('sample', RandomInduced(k_params=[25]).states), # induced measure of parameter 25 ('ppt only', select(PPT.is_respected, True)), # respecting the PPT criterion ('fw', add(FrankWolfe(1000).approximation, key = 'approx'), # compute the sep approx. ('sel ent', select(DistToSep(0.01, sep_key = 'fw__approx').predict, Label.ENT)) @@ -55,6 +56,7 @@ def sampler(n_states : int, dims : list[int]) -> DMStack, dict ``` the following samplers can be found in the library : +- samplers.utils.FromSet - samplers.pure.RandomHaar - samplers.mixed.RandomInduced - samplers.mixed.RandomBures @@ -71,8 +73,8 @@ def transformer(states : DMStack, infos : dict) -> DMStack, dict the following transformers can be found in the library : - transformers.sep_approximations.FrankWolfe -- transformers.real_representation.GellMann -- transformer.real_representation.Measures +- transformers.representations.GellMann +- transformer.representations.Measures ### model diff --git a/src/models/approx_based.py b/src/models/approx_based.py index 8e987381f037ab7eacadaa27755d6b5ecaf80d10..e82a4aba232ab96deea47a48ac746473f5f46f25 100644 --- a/src/models/approx_based.py +++ b/src/models/approx_based.py @@ -7,6 +7,16 @@ from ..types import Label import numpy as np +class MlModel : + """ + Use a machine learning model (sklearn) as model + """ + def __init__(self, model) : + self.model = model + + def predict(self, states, infos={}): + return self.model.predict(state), {} + class DistToSep: """ Distance from a separable approximation