From e236fcd2fb15ee09100fa939fe261dc07a61ab2f Mon Sep 17 00:00:00 2001
From: Balthazar Casale <balthazar.casale@lis-lab.fr>
Date: Fri, 9 Jun 2023 18:10:56 +0200
Subject: [PATCH] Update 2 files

- /src/models/approx_based.py
- /README.md
---
 README.md                  |  8 +++++---
 src/models/approx_based.py | 10 ++++++++++
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 54de1ab..d5fb6ec 100644
--- a/README.md
+++ b/README.md
@@ -24,12 +24,13 @@ We give a typical use case in the following snipped of code :
 ```python
 from types import save_dmstack, load_dmstack
 from pipeline import *
+from samplers.mixed import RandomInduced
 from models.criteria import PPT
 from models.approx_based import DistToSep
 from transformers.sep_approximation import FrankWolfe
 
 states, infos = Pipeline([
-	('sample', InducedMeasure(k_params=[25]).states),		# induced measure of parameter 25
+	('sample', RandomInduced(k_params=[25]).states),		# induced measure of parameter 25
 	('ppt only', select(PPT.is_respected, True)),			# respecting the PPT criterion
 	('fw', add(FrankWolfe(1000).approximation, key = 'approx'), # compute the sep approx.
 	('sel ent', select(DistToSep(0.01, sep_key = 'fw__approx').predict, Label.ENT))
@@ -55,6 +56,7 @@ def sampler(n_states : int, dims : list[int]) -> DMStack, dict
 ```
 the following samplers can be found in the library :
 
+- samplers.utils.FromSet
 - samplers.pure.RandomHaar
 - samplers.mixed.RandomInduced
 - samplers.mixed.RandomBures
@@ -71,8 +73,8 @@ def transformer(states : DMStack, infos : dict) -> DMStack, dict
 the following transformers can be found in the library :
 
 - transformers.sep_approximations.FrankWolfe
-- transformers.real_representation.GellMann
-- transformer.real_representation.Measures
+- transformers.representations.GellMann
+- transformer.representations.Measures
 
 ### model
 
diff --git a/src/models/approx_based.py b/src/models/approx_based.py
index 8e98738..e82a4ab 100644
--- a/src/models/approx_based.py
+++ b/src/models/approx_based.py
@@ -7,6 +7,16 @@ from ..types import Label
 import numpy as np
 
 
+class MlModel :
+    """
+    Use a machine learning model (sklearn) as model
+    """
+    def __init__(self, model) :
+        self.model = model
+
+    def predict(self, states, infos={}):
+        return self.model.predict(state), {}
+
 class DistToSep:
     """
     Distance from a separable approximation
-- 
GitLab