Skip to content
Snippets Groups Projects
Select Git revision
  • b71be23a22dfcaaaab4317f71b443179ba43a384
  • master default
  • object
  • develop protected
  • private_algos
  • cuisine
  • SMOTE
  • revert-76c4cca5
  • archive protected
  • no_graphviz
  • 0.0.1
11 results

Dataset.py

Blame
  • datasets.py 8.20 KiB
    # -*- coding: utf-8 -*-
    """
    
    .. moduleauthor:: Valentin Emiya
    """
    from pathlib import Path
    
    import numpy as np
    from matplotlib import pyplot as plt
    from scipy.ndimage import \
        binary_opening, binary_closing, generate_binary_structure
    
    from madarrays import Waveform
    from ltfatpy import plotdgtreal
    
    from tffpy.tf_tools import get_signal_params, get_dgt_params, GaborMultiplier
    from tffpy.utils import dgt, db, plot_spectrogram, plot_mask, get_data_path
    
    
    # Root folder of the data (as returned by `get_data_path`) and the
    # sub-folder used by default by the functions of this module.
    default_data_root_dir = get_data_path()
    default_data_dir = default_data_root_dir.joinpath('data_8000Hz_16384samples')
    
    
    def get_dataset(data_dir=None):
        """
        Get dataset for isolated wideband and localized sources before mixing.

        Parameters
        ----------
        data_dir : str or Path or None
            Folder that contains the sub-folders ``wide_band_sources`` and
            ``localized_sources``. If None, the module-level variable
            ``default_data_dir`` is used.

        Returns
        -------
        dataset : dict
            dataset['wideband'] (resp. dataset['localized']) is a dictionary
            containing the :py:class:`Path` object for all the wideband
            (resp. localized) sounds. Keys are the file names without
            extension (file stems) and only ``*.wav`` files are listed.
            Entries are inserted in sorted path order so that iterating over
            the dataset is reproducible.
        """
        root_dir = default_data_dir if data_dir is None else Path(data_dir)
        # Dataset key -> name of the sub-folder holding the related sounds
        sub_dirs = {
            'wideband': 'wide_band_sources',
            'localized': 'localized_sources',
        }
        return {
            kind: {
                x.stem: x
                # glob order depends on the file system: sort it explicitly
                for x in sorted((root_dir / sub_dir).glob('*.wav'))
            }
            for kind, sub_dir in sub_dirs.items()
        }
    
    
    def _smooth_mask(mask_raw, n_iter_closing, n_iter_opening, closing_first):
        """
        Smooth a boolean mask with morphological closings and openings.

        A step whose number of iterations is not positive is skipped: passing
        such a value to scipy would not disable the step since scipy then
        repeats the operation until the result does not change anymore.

        Parameters
        ----------
        mask_raw : nd-array
            Boolean mask to be smoothed.
        n_iter_closing : int
            Number of iterations of the closing (skipped if not positive).
        n_iter_opening : int
            Number of iterations of the opening (skipped if not positive).
        closing_first : bool
            If True, the closing is applied before the opening. If False, the
            opening is applied first.

        Returns
        -------
        nd-array
            Smoothed mask (`mask_raw` itself if both steps are skipped).
        """
        struct = generate_binary_structure(2, 1)

        def closing(m):
            if n_iter_closing <= 0:
                return m
            return binary_closing(input=m, structure=struct,
                                  iterations=n_iter_closing, border_value=1)

        def opening(m):
            if n_iter_opening <= 0:
                return m
            return binary_opening(input=m, structure=struct,
                                  iterations=n_iter_opening, border_value=0)

        if closing_first:
            return opening(closing(mask_raw))
        return closing(opening(mask_raw))


    def get_mix(loc_source, wideband_src, crop=None,
                wb_to_loc_ratio_db=0, win_dur=128 / 8000, win_type='gauss',
                hop_ratio=1/4, n_bins_ratio=4, n_iter_closing=2,
                n_iter_opening=2, delta_mix_db=0, delta_loc_db=30,
                closing_first=True, or_mask=False,
                fig_dir=None, prefix=''):
        """
        Build the mix of two sounds and the related time-frequency boolean mask.

        Parameters
        ----------
        loc_source : str
            Name of the localized sound, i.e., a key of
            ``get_dataset()['localized']`` (file name without extension).
        wideband_src : str
            Name of the wideband sound, i.e., a key of
            ``get_dataset()['wideband']`` (file name without extension).
        crop : int or None
            If not None, a cropped, centered portion of the sound will be
            extracted with the specified length, in samples.
        wb_to_loc_ratio_db : float
            Wideband source to localized source energy ratio to be adjusted in
            the mix.
        win_dur : float
            Window duration, in seconds.
        win_type : str
            Window name
        hop_ratio : float
            Ratio of the window length that will be set as hop size for the DGT.
        n_bins_ratio : float
            Factor that will be applied to the window length to compute the
            number of bins in the DGT.
        delta_mix_db : float
            Coefficient energy ratio, in dB, between the wideband source and the
            localized source in the mixture in order to select coefficients in
            the mask.
        delta_loc_db : float
            Dynamic range, in dB, for the localized source in order to select
            coefficients in the mask.
        or_mask : bool
            If True, the mask is built by taking the union of the two masks
            obtained using thresholds `delta_mix_db` and `delta_loc_db`. If
            False, the intersection is taken.
        n_iter_closing : int
            Number of successive morphological closings with radius 1 (a.k.a.
            radius of one single closing). If not positive, no closing is
            applied.
        n_iter_opening : int
            Number of successive morphological openings with radius 1 (a.k.a.
            radius of one single opening). If not positive, no opening is
            applied.
        closing_first : bool
            If True, morphological closings are applied first, followed by
            openings. If False, the reverse way is used.
        fig_dir : str or Path
            If not None, folder where figures are stored (created if needed).
            If None, figures are not plotted.
        prefix : str or None
            If not None and not empty, this prefix, followed by an underscore,
            is prepended to the name of the ``mix_spectro_mask.pdf`` figure.

        Returns
        -------
        x_mix : Waveform
            Mix signal (sum of outputs `x_loc` and `x_wb`)
        dgt_params : dict
            DGT parameters
        signal_params : dict
            Signal parameters
        mask : nd-array
            Time-frequency binary mask
        x_loc : Waveform
            Localized source signal
        x_wb : Waveform
            Wideband source signal

        Raises
        ------
        KeyError
            If `loc_source` or `wideband_src` is not found in the dataset.
        AssertionError
            If the two sounds do not have the same shape.
        ValueError
            If `crop` is not positive or is larger than the signal length.
        """
        dataset = get_dataset()

        x_loc = Waveform.from_wavfile(dataset['localized'][loc_source])
        x_wb = Waveform.from_wavfile(dataset['wideband'][wideband_src])
        np.testing.assert_array_equal(x_loc.shape, x_wb.shape)
        if crop is not None:
            x_len = crop
            # An out-of-range length would give a negative start index or an
            # empty signal, i.e., a silently wrong excerpt.
            if not 0 < x_len <= x_loc.shape[0]:
                raise ValueError(
                    'crop should be in [1, {}], got {}'.format(x_loc.shape[0],
                                                               x_len))
            i_start = (x_loc.shape[0] - x_len) // 2
            x_loc = x_loc[i_start:i_start+x_len]
            x_wb = x_wb[i_start:i_start+x_len]
        signal_params = get_signal_params(sig_len=x_loc.shape[0], fs=x_loc.fs)

        # Unit energy
        x_loc /= np.linalg.norm(x_loc)
        x_wb /= np.linalg.norm(x_wb)
        # The two gains sum to one and the ratio of the wideband gain to the
        # localized gain equals 10 ** (wb_to_loc_ratio_db / 20)
        gain_wb = 1 / (1 + 10 ** (-wb_to_loc_ratio_db / 20))
        x_loc *= (1 - gain_wb)
        x_wb *= gain_wb

        # Build mix
        x_mix = x_loc + x_wb

        # Build dgt: the window length is rounded to the nearest power of two
        fs = x_loc.fs
        approx_win_len = int(2 ** np.round(np.log2(win_dur * fs)))
        hop = int(approx_win_len * hop_ratio)
        n_bins = int(approx_win_len * n_bins_ratio)
        sig_len = x_loc.shape[0]
        dgt_params = get_dgt_params(win_type=win_type,
                                    approx_win_len=approx_win_len,
                                    hop=hop, n_bins=n_bins, sig_len=sig_len)

        tf_mat_loc_db = db(np.abs(dgt(x_loc, dgt_params=dgt_params)))
        tf_mat_wb_db = db(np.abs(dgt(x_wb, dgt_params=dgt_params)))

        # Build mask_raw
        # Coefficients where the localized source dominates the wideband one
        mask_mix = tf_mat_loc_db > tf_mat_wb_db + delta_mix_db
        # Coefficients within the dynamic range of the localized source
        mask_loc = tf_mat_loc_db > tf_mat_loc_db.max() - delta_loc_db

        if or_mask:
            mask_raw = np.logical_or(mask_mix, mask_loc)
        else:
            mask_raw = np.logical_and(mask_mix, mask_loc)

        # Smooth the mask (each step is skipped if its count is not positive)
        mask = _smooth_mask(mask_raw=mask_raw,
                            n_iter_closing=n_iter_closing,
                            n_iter_opening=n_iter_opening,
                            closing_first=closing_first)

        if fig_dir is not None:
            fig_dir = Path(fig_dir)
            fig_dir.mkdir(exist_ok=True, parents=True)
            # None or empty prefix: no prefix in the file name
            prefix = '{}_'.format(prefix) if prefix else ''

            # NOTE(review): the file names of the two following figures do
            # not include `prefix`, unlike the last figure.
            plt.figure()
            plot_mask(mask=mask_mix, hop=dgt_params['hop'],
                      n_bins=dgt_params['n_bins'], fs=signal_params['fs'])
            plt.title('Mask Mix - Area: {} ({:.1%})'.format(mask_mix.sum(),
                                                            np.average(mask_mix)))
            plt.tight_layout()
            plt.savefig(fig_dir / 'mask_mix.pdf')

            plt.figure()
            plot_mask(mask=mask_loc, hop=dgt_params['hop'],
                      n_bins=dgt_params['n_bins'], fs=signal_params['fs'])
            plt.title('Mask Loc - Area: {} ({:.1%})'.format(mask_loc.sum(),
                                                            np.average(mask_loc)))
            plt.tight_layout()
            plt.savefig(fig_dir / 'mask_loc.pdf')

            # Summary figure with 2 x 3 panels
            plt.figure()
            plt.subplot(231)
            plot_spectrogram(x=x_mix, dgt_params=dgt_params, fs=fs)
            plt.title('Mix')
            plt.tight_layout()
            plt.subplot(232)
            plot_mask(mask=mask_raw, hop=dgt_params['hop'],
                      n_bins=dgt_params['n_bins'], fs=fs)
            plt.title('Raw mask_raw')
            plt.tight_layout()
            plt.subplot(233)
            plot_mask(mask=mask, hop=dgt_params['hop'],
                      n_bins=dgt_params['n_bins'], fs=fs)
            plt.tight_layout()
            plt.title('Smoothed mask_raw')
            plt.subplot(234)
            plot_spectrogram(x=x_loc, dgt_params=dgt_params, fs=fs)
            plt.title('Loc')
            plt.tight_layout()
            plt.subplot(235)
            # DGT coefficients of the localized source kept by the mask
            tf_mat = dgt(x_loc, dgt_params=dgt_params) * mask
            plotdgtreal(coef=tf_mat, a=dgt_params['hop'], M=dgt_params['n_bins'],
                        fs=fs, dynrange=100)
            plt.title('Masked loc')
            plt.tight_layout()
            plt.subplot(236)
            # Wideband source filtered by the complement of the mask
            gabmul = GaborMultiplier(mask=~mask,
                                     dgt_params=dgt_params,
                                     signal_params=signal_params)
            x_est = gabmul @ x_wb
            plot_spectrogram(x=x_est, dgt_params=dgt_params, fs=fs)
            plt.title('Filtered wb')
            plt.tight_layout()
            plt.savefig(fig_dir / '{}mix_spectro_mask.pdf'.format(prefix))

        return x_mix, dgt_params, signal_params, mask, x_loc, x_wb