Commit 35d0e4a0 authored by Raphael

bug fix hmm

parent 9487015c
1 merge request: !13 Draft: Develop
@@ -30,41 +30,18 @@ def split_trajectories(feature_seq, label_seq, n_classes):
     return result, sequence_list
 
 
-class GMMHMMClassifier:
-    def __init__(self, nb_states):
-        self.n_features = 0
-        if type(nb_states) is not list:
-            self.nb_states = np.array([nb_states])
-        else:
-            self.nb_states = np.array(nb_states)
-        self.hmms = []
-        self.n_classes = len(self.nb_states)
-        for i, nb_state in enumerate(self.nb_states):
-            self.hmms.append(GaussianHMM(n_components=nb_state, covariance_type='full'))
-        self.hmm = GaussianHMM(n_components=sum(self.nb_states), covariance_type='full', init_params='', n_iter=100)
-        self.predict_dictionary = {}
-        self.predictor = []
-        count = 0
-        for i, ii in enumerate(self.nb_states):
-            self.predictor.append({})
-            for j in range(ii):
-                self.predictor[i][j] = count
-                self.predict_dictionary[count] = i
-                count += 1
-
-    def fit(self, x, y):
-        sequences = {i: [] for i in range(self.n_classes)}
-        self.n_features = x[0].shape[1]
+def get_sequences(x, y, n_classes):
+    sequences = {i: [] for i in range(n_classes)}
     for feature_seq, label_seq in zip(x, y):
-            split_seq, _ = split_trajectories(feature_seq, label_seq, self.n_classes)
+        split_seq, _ = split_trajectories(feature_seq, label_seq, n_classes)
         for key in sequences.keys():
             sequences[key] += split_seq[key]
+    return sequences
+
+
+def fit_hmmms(self, sequences):
     for i, seqs in sequences.items():
         self.hmms[i].n_features = self.n_features
         if sum([np.array(s).size for s in seqs]) > sum(self.hmms[i]._get_n_fit_scalars_per_param().values()):
@@ -76,21 +53,26 @@ class GMMHMMClassifier:
         else:
             self.hmms[i] = None
+
+
+def get_predictions(x, y, hmms, predictor, n_classes):
     predict = []
     for feature_seq, label_seq in zip(x, y):
-            _, sequences_list = split_trajectories(feature_seq, label_seq, self.n_classes)
+        _, sequences_list = split_trajectories(feature_seq, label_seq, n_classes)
         pred = np.array([])
         for label, seq in sequences_list:
-                if self.hmms[label] is not None:
-                    _, state_sequence = self.hmms[label].decode(np.array(seq), [len(seq)])
-                    pred = np.append(pred, [self.predictor[label][i] for i in state_sequence])
+            if hmms[label] is not None:
+                _, state_sequence = hmms[label].decode(np.array(seq), [len(seq)])
+                pred = np.append(pred, [predictor[label][i] for i in state_sequence])
         if len(pred) != 0:
             predict.append(pred)
+    return predict
-        start = np.zeros(sum(self.nb_states))
-        T_mat = np.zeros((sum(self.nb_states), sum(self.nb_states)))
+
+
+def get_new_hmm_values(nb_states, predict):
+    start = np.zeros(sum(nb_states))
+    T_mat = np.zeros((sum(nb_states), sum(nb_states)))
     prior = -1
-        count = np.zeros(sum(self.nb_states))
+    count = np.zeros(sum(nb_states))
     for pred in predict:
         start[int(pred[0])] += 1
@@ -100,36 +82,85 @@ class GMMHMMClassifier:
         count[prior] += 1
         prior = int(p)
-        for i in range(sum(self.nb_states)):
-            for j in range(sum(self.nb_states)):
+    for i in range(sum(nb_states)):
+        for j in range(sum(nb_states)):
             if T_mat[i][j] > 0:
                 T_mat[i][j] = T_mat[i][j] / count[i]
-        self.hmm.startprob_ = start / sum(start)
-        self.hmm.transmat_ = T_mat
-        for i, value in enumerate(self.hmm.transmat_.sum(axis=1)):
+    for i, value in enumerate(T_mat.sum(axis=1)):
         if value == 0:
-                self.hmm.transmat_[i][i] = 1.0
+            T_mat[i][i] = 1.0
+    return start, T_mat, count
+
+
+def get_means_and_covars(hmms, nb_states, nb_features, degens):
     means = []
     covars = []
-        for i, model in enumerate(self.hmms):
-            if self.hmms[i] is not None:
+    for i, model in enumerate(hmms):
+        if hmms[i] is not None:
             means.append(model.means_)
             covars.append(model.covars_)
         else:
-                means.append(np.zeros((self.nb_states[i], x[0].shape[1])))
-                covars.append(np.stack([make_spd_matrix(x[0].shape[1])
-                                        for _ in range(self.nb_states[i])], axis=0))
+            means.append(np.zeros((nb_states[i], nb_features)))
+            covars.append(np.stack([make_spd_matrix(nb_features)
+                                    for _ in range(nb_states[i])], axis=0))
     means = np.concatenate(means)
     covars = np.concatenate(covars)
     for n, cv in enumerate(covars):
-            if count[n] <= 3:
+        if degens[n] and np.any(linalg.eigvalsh(cv) > 0):
             covars[n] = np.identity(cv.shape[0])
         if not np.allclose(cv, cv.T) or np.any(linalg.eigvalsh(cv) <= 0):
            limit = 0
            while (not np.allclose(cv, cv.T) or np.any(linalg.eigvalsh(cv) <= 0)):
                covars[n] += np.identity(cv.shape[0]) * 10 ** -15
                if limit > 100:
                    covars[n] = np.identity(cv.shape[0])
                    break
+    return means, covars
+
+
+class GMMHMMClassifier:
+    def __init__(self, nb_states, max_iter=100, verbose=False):
+        self.n_features = 0
+        if type(nb_states) is not list:
+            self.nb_states = np.array([nb_states])
+        else:
+            self.nb_states = np.array(nb_states)
+        self.degen_ = [False for _ in range(sum(self.nb_states))]
+        self.hmms = []
+        self.n_classes = len(self.nb_states)
+        for i, nb_state in enumerate(self.nb_states):
+            self.hmms.append(GaussianHMM(n_components=nb_state, covariance_type='full', verbose=verbose, n_iter=max_iter))
+        self.hmm = GaussianHMM(n_components=sum(self.nb_states), covariance_type='full', init_params='', n_iter=max_iter)
+        self.predict_dictionary = {}
+        self.predictor = []
+        count = 0
+        for i, ii in enumerate(self.nb_states):
+            self.predictor.append({})
+            for j in range(ii):
+                self.predictor[i][j] = count
+                self.predict_dictionary[count] = i
+                count += 1
+
+    def fit(self, x, y):
+        self.n_features = x[0].shape[1]
+        sequences = get_sequences(x, y, self.n_classes)
+        fit_hmmms(self, sequences)
+        predict = get_predictions(x, y, self.hmms, self.predictor, self.n_classes)
+        start, T_mat, count = get_new_hmm_values(self.nb_states, predict)
+        self.get_degens(count)
+        self.hmm.startprob_ = start / sum(start)
+        self.hmm.transmat_ = T_mat
+        means, covars = get_means_and_covars(self.hmms, self.nb_states, self.n_features, self.degen_)
+        self.hmm.means_ = means
+        self.hmm.covars_ = covars
@@ -146,6 +177,11 @@ class GMMHMMClassifier:
         return self.hmm.predict(X_all, lenghts)
 
+    def get_degens(self, count):
+        for i, c in enumerate(count):
+            if c < self.n_features:
+                self.degen_[i] = True
+
 
 @jit(nopython=True)
 def hmm_probabilities(predict, nb_states):
     n_states = nb_states.sum()
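
For orientation (not part of the commit), a minimal usage sketch of the refactored GMMHMMClassifier follows. The module name gmmhmm, the toy data, and its shapes are assumptions inferred from how fit() consumes x and y in this diff; only the construction with nb_states/max_iter/verbose and the call to fit() come from the code above.

# Hypothetical usage sketch -- assumes numpy, hmmlearn, scikit-learn and numba
# (the libraries the module imports) are installed.
import numpy as np
from gmmhmm import GMMHMMClassifier  # module name is an assumption

# Toy data: 5 trajectories of 50 timesteps, 2 features, labels from 3 classes.
rng = np.random.default_rng(0)
x = [rng.normal(size=(50, 2)) for _ in range(5)]
y = [rng.integers(0, 3, size=50) for _ in range(5)]

# With nb_states=[2, 3, 2] the per-class hidden states are re-indexed into one
# global state space (class 0 -> states 0-1, class 1 -> states 2-4,
# class 2 -> states 5-6), which is what self.predictor and
# self.predict_dictionary encode in __init__.
clf = GMMHMMClassifier(nb_states=[2, 3, 2], max_iter=100, verbose=False)
clf.fit(x, y)  # fits one GaussianHMM per class, then assembles the global HMM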