Skip to content
Snippets Groups Projects
Commit a3413da0 authored by Paul Best
Browse files

semi sup umap

parent a6f7fa15
No related branches found
No related tags found
No related merge requests found
...@@ -32,6 +32,7 @@ parser.add_argument('-channel', type=int, default=0) ...@@ -32,6 +32,7 @@ parser.add_argument('-channel', type=int, default=0)
parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)") parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
parser.add_argument('-min_sample', type=int, default=3, help='Used for HDBSCAN clustering.') parser.add_argument('-min_sample', type=int, default=3, help='Used for HDBSCAN clustering.')
parser.add_argument('-eps', type=float, default=0.01, help='Used for HDBSCAN clustering.') parser.add_argument('-eps', type=float, default=0.01, help='Used for HDBSCAN clustering.')
parser.add_argument('-semi_sup_umap', type=str, default=None, help='Name of a column in the detections file with labels to guide the umap compression (see target_metric in umap documentation)')
args = parser.parse_args() args = parser.parse_args()
df = pd.read_csv(args.detections) df = pd.read_csv(args.detections)
...@@ -42,6 +43,8 @@ frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel ...@@ -42,6 +43,8 @@ frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel
args.sampleDur += (2 if args.medfilt else 0) args.sampleDur += (2 if args.medfilt else 0)
if args.umap_ndim == 2: if args.umap_ndim == 2:
if args.semi_sup_umap:
umap_ = umap.UMAP(n_jobs=-1).fit_transform(encodings['encodings'], y=df.loc[idxs, args.semi_sup_umap])
df.loc[idxs, 'umap_x'] = umap_[:,0] df.loc[idxs, 'umap_x'] = umap_[:,0]
df.loc[idxs, 'umap_y'] = umap_[:,1] df.loc[idxs, 'umap_y'] = umap_[:,1]
...@@ -101,7 +104,7 @@ if args.umap_ndim == 2: ...@@ -101,7 +104,7 @@ if args.umap_ndim == 2:
plt.show() plt.show()
else : else :
umap_ = umap.UMAP(n_jobs=-1, n_components=args.umap_ndim).fit_transform(embeddings) umap_ = umap.UMAP(n_jobs=-1, n_components=args.umap_ndim).fit_transform(embeddings, y=df[args.semi_sup_umap] if args.semi_sup_umap else None)
df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size, df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
min_samples=args.min_sample, min_samples=args.min_sample,
core_dist_n_jobs=-1, core_dist_n_jobs=-1,
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment