Skip to content
Snippets Groups Projects
Commit 9d3a3e6d authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

cvxopt_forced

parent e0a9330f
No related branches found
No related tags found
No related merge requests found
...@@ -186,7 +186,7 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -186,7 +186,7 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost):
if previous_w is not None: if previous_w is not None:
qp.initial_values = np.append(previous_w, [0]) qp.initial_values = np.append(previous_w, [0])
try: # try:
solver_result = qp.solve(abstol=1e-10, reltol=1e-10, feastol=1e-10, return_all_information=True) solver_result = qp.solve(abstol=1e-10, reltol=1e-10, feastol=1e-10, return_all_information=True)
w = np.asarray(np.array(solver_result['x']).T[0])[:n_hypotheses] w = np.asarray(np.array(solver_result['x']).T[0])[:n_hypotheses]
...@@ -200,17 +200,17 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -200,17 +200,17 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost):
self.dual_constraint_rhs = dual_variables[-1] self.dual_constraint_rhs = dual_variables[-1]
# logging.info('Updating dual constraint rhs: {}'.format(self.dual_constraint_rhs)) # logging.info('Updating dual constraint rhs: {}'.format(self.dual_constraint_rhs))
except: # except:
logging.warning('QP Solving failed at iteration {}.'.format(n_hypotheses)) # logging.warning('QP Solving failed at iteration {}.'.format(n_hypotheses))
if previous_w is not None: # if previous_w is not None:
w = np.append(previous_w, [0]) # w = np.append(previous_w, [0])
else: # else:
w = np.array([1.0 / n_hypotheses] * n_hypotheses) # w = np.array([1.0 / n_hypotheses] * n_hypotheses)
#
if previous_alpha is not None: # if previous_alpha is not None:
alpha = previous_alpha # alpha = previous_alpha
else: # else:
alpha = self._initialize_alphas(n_examples) # alpha = self._initialize_alphas(n_examples)
return w, alpha return w, alpha
......
...@@ -13,6 +13,12 @@ def testVersions(): ...@@ -13,6 +13,12 @@ def testVersions():
except ImportError: except ImportError:
raise raise
try:
import cvxopt
except ImportError:
isUpToDate = False
toInstall.append("cvxopt")
try: try:
import pyscm import pyscm
except ImportError: except ImportError:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment