diff --git a/code/playground/nn_omp.py b/code/playground/nn_omp.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..382cb6ae52e3f7f43e7ff207003379d8bdcd9dfe 100644 --- a/code/playground/nn_omp.py +++ b/code/playground/nn_omp.py @@ -0,0 +1,76 @@ +from scipy.optimize import nnls +import numpy as np + + +def nn_omp(T, y, max_iter, intermediate_solutions_sizes=None): + """ + Ref: Sparse Non-Negative Solution of a +Linear System of Equations is Unique + + T: (N x L) + y: (N x 1) + max_iter: the max number of iteration. If intermediate_solutions_sizes is None. Return the max_iter-sparse solution. + intermediate_solutions_sizes: a list of the other returned intermediate solutions than with max_iter (they are returned in a list with same indexes) + """ + if intermediate_solutions_sizes is None: + intermediate_solutions_sizes = [max_iter] + # elif max_iter not in intermediate_solutions_sizes: + # intermediate_solutions_sizes.append(max_iter) + + assert all(type(elm) == int for elm in intermediate_solutions_sizes), "All intermediate solution must be size specified as integers." + + iter_intermediate_solutions_sizes = iter(intermediate_solutions_sizes) + + lst_intermediate_solutions = [] + bool_arr_selected_indexes = np.zeros(T.shape[1], dtype=bool) + residual = y + i = 0 + next_solution = next(iter_intermediate_solutions_sizes, None) + while i < max_iter and next_solution != None: + print("iter {}".format(i)) + dot_products = T.T @ residual + + idx_max_dot_product = np.argmax(dot_products) + if dot_products[idx_max_dot_product] <= 0: + print("No other atoms is positively correlated with the residual. End prematurely with {} atoms.".format(i+1)) + break + + bool_arr_selected_indexes[idx_max_dot_product] = True + + tmp_T = T[:, bool_arr_selected_indexes] + sol = nnls(tmp_T, y)[0] + residual = y - tmp_T @ sol + + if i+1 == next_solution: + final_vec = np.zeros(T.shape[1]) + final_vec[bool_arr_selected_indexes] = sol + lst_intermediate_solutions.append(final_vec) + next_solution = next(iter_intermediate_solutions_sizes, None) + + i+=1 + + if len(lst_intermediate_solutions) == 1: + return lst_intermediate_solutions[-1] + else: + return lst_intermediate_solutions + +if __name__ == "__main__": + + N = 1000 + L = 100 + K = 10 + + T = np.random.rand(N, L) + w_star = np.abs(np.random.rand(L)) + + T /= np.linalg.norm(T, axis=0) + y = T @ w_star + + requested_solutions = list(range(1, L, 10)) + solutions = nn_omp(T, y, L, requested_solutions) + + for idx_sol, w in enumerate(solutions): + solution = T @ w + non_zero = w.astype(bool) + print(requested_solutions[idx_sol], np.sum(non_zero), np.linalg.norm(solution - y)/np.linalg.norm(y)) +