Skip to content
Snippets Groups Projects
Commit e045a789 authored by Luc Giffon's avatar Luc Giffon
Browse files

implement loop nn_omp

parent 2a24aacb
No related branches found
No related tags found
1 merge request!24Resolve "non negative omp"
from scipy.optimize import nnls
import numpy as np
def nn_omp(T, y, max_iter, intermediate_solutions_sizes=None):
"""
Ref: Sparse Non-Negative Solution of a
Linear System of Equations is Unique
T: (N x L)
y: (N x 1)
max_iter: the max number of iteration. If intermediate_solutions_sizes is None. Return the max_iter-sparse solution.
intermediate_solutions_sizes: a list of the other returned intermediate solutions than with max_iter (they are returned in a list with same indexes)
"""
if intermediate_solutions_sizes is None:
intermediate_solutions_sizes = [max_iter]
# elif max_iter not in intermediate_solutions_sizes:
# intermediate_solutions_sizes.append(max_iter)
assert all(type(elm) == int for elm in intermediate_solutions_sizes), "All intermediate solution must be size specified as integers."
iter_intermediate_solutions_sizes = iter(intermediate_solutions_sizes)
lst_intermediate_solutions = []
bool_arr_selected_indexes = np.zeros(T.shape[1], dtype=bool)
residual = y
i = 0
next_solution = next(iter_intermediate_solutions_sizes, None)
while i < max_iter and next_solution != None:
print("iter {}".format(i))
dot_products = T.T @ residual
idx_max_dot_product = np.argmax(dot_products)
if dot_products[idx_max_dot_product] <= 0:
print("No other atoms is positively correlated with the residual. End prematurely with {} atoms.".format(i+1))
break
bool_arr_selected_indexes[idx_max_dot_product] = True
tmp_T = T[:, bool_arr_selected_indexes]
sol = nnls(tmp_T, y)[0]
residual = y - tmp_T @ sol
if i+1 == next_solution:
final_vec = np.zeros(T.shape[1])
final_vec[bool_arr_selected_indexes] = sol
lst_intermediate_solutions.append(final_vec)
next_solution = next(iter_intermediate_solutions_sizes, None)
i+=1
if len(lst_intermediate_solutions) == 1:
return lst_intermediate_solutions[-1]
else:
return lst_intermediate_solutions
if __name__ == "__main__":
N = 1000
L = 100
K = 10
T = np.random.rand(N, L)
w_star = np.abs(np.random.rand(L))
T /= np.linalg.norm(T, axis=0)
y = T @ w_star
requested_solutions = list(range(1, L, 10))
solutions = nn_omp(T, y, L, requested_solutions)
for idx_sol, w in enumerate(solutions):
solution = T @ w
non_zero = w.astype(bool)
print(requested_solutions[idx_sol], np.sum(non_zero), np.linalg.norm(solution - y)/np.linalg.norm(y))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment