From e236fcd2fb15ee09100fa939fe261dc07a61ab2f Mon Sep 17 00:00:00 2001 From: Balthazar Casale <balthazar.casale@lis-lab.fr> Date: Fri, 9 Jun 2023 18:10:56 +0200 Subject: [PATCH] Update 2 files - /src/models/approx_based.py - /README.md --- README.md | 8 +++++--- src/models/approx_based.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 54de1ab..d5fb6ec 100644 --- a/README.md +++ b/README.md @@ -24,12 +24,13 @@ We give a typical use case in the following snipped of code : ```python from types import save_dmstack, load_dmstack from pipeline import * +from samplers.mixed import RandomInduced from models.criteria import PPT from models.approx_based import DistToSep from transformers.sep_approximation import FrankWolfe states, infos = Pipeline([ - ('sample', InducedMeasure(k_params=[25]).states), # induced measure of parameter 25 + ('sample', RandomInduced(k_params=[25]).states), # induced measure of parameter 25 ('ppt only', select(PPT.is_respected, True)), # respecting the PPT criterion ('fw', add(FrankWolfe(1000).approximation, key = 'approx'), # compute the sep approx. ('sel ent', select(DistToSep(0.01, sep_key = 'fw__approx').predict, Label.ENT)) @@ -55,6 +56,7 @@ def sampler(n_states : int, dims : list[int]) -> DMStack, dict ``` the following samplers can be found in the library : +- samplers.utils.FromSet - samplers.pure.RandomHaar - samplers.mixed.RandomInduced - samplers.mixed.RandomBures @@ -71,8 +73,8 @@ def transformer(states : DMStack, infos : dict) -> DMStack, dict the following transformers can be found in the library : - transformers.sep_approximations.FrankWolfe -- transformers.real_representation.GellMann -- transformer.real_representation.Measures +- transformers.representations.GellMann +- transformer.representations.Measures ### model diff --git a/src/models/approx_based.py b/src/models/approx_based.py index 8e98738..e82a4ab 100644 --- a/src/models/approx_based.py +++ b/src/models/approx_based.py @@ -7,6 +7,16 @@ from ..types import Label import numpy as np +class MlModel : + """ + Use a machine learning model (sklearn) as model + """ + def __init__(self, model) : + self.model = model + + def predict(self, states, infos={}): + return self.model.predict(state), {} + class DistToSep: """ Distance from a separable approximation -- GitLab