diff --git a/python/tffpy/datasets.py b/python/tffpy/datasets.py
index 13cac9fb553841826ec3ddcf858be7dee7bcb3cd..cd05466ba21506a5732b2f5f2d94ec4fb4a2bd6c 100644
--- a/python/tffpy/datasets.py
+++ b/python/tffpy/datasets.py
@@ -94,7 +94,7 @@ def get_mix(loc_source, wideband_src, crop=None,
     closing_first : bool
         If True, morphological closings are applied first, followed by
         openings. If False, the reverse way is used.
-    fig_dir : Path
+    fig_dir : str or Path
         If not None, folder where figures are stored. If None, figures are
         not plotted.
     prefix : str
diff --git a/python/tffpy/interpolation_solver.py b/python/tffpy/interpolation_solver.py
index a72e26db63f0ada52338a59c3da882037f51fd22..442255dbd05d64321b824931c3827500f72393dd 100644
--- a/python/tffpy/interpolation_solver.py
+++ b/python/tffpy/interpolation_solver.py
@@ -9,12 +9,33 @@ import matplotlib.pyplot as plt
 from pathlib import Path
 from ltfatpy import plotdgtreal
 
-from tffpy.datasets import get_mix
 from tffpy.utils import dgt, plot_spectrogram, plot_mask, idgt
 
 
 def solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
                            fig_dir=None):
+    """
+    Time-frequency fading solver using linear interpolation and random phases
+
+    Parameters
+    ----------
+    x_mix : nd-array
+        Mix signal
+    mask : nd-array
+        Time-frequency mask
+    dgt_params : dict
+        DGT parameters
+    signal_params : dict
+        Signal parameters
+    fig_dir : str or Path
+        If not None, folder where figures are stored. If None, figures are
+        not plotted.
+
+    Returns
+    -------
+    nd-array
+        Estimated signal
+    """
     x_tf = dgt(sig=x_mix, dgt_params=dgt_params)
     mask = mask > 0
     x_tf[mask] = np.nan
@@ -31,49 +52,25 @@ def solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
     x_est = idgt(tf_mat=x_tf, dgt_params=dgt_params,
                  sig_len=signal_params['sig_len'])
     if fig_dir is not None:
+        fig_dir = Path(fig_dir)
+        fig_dir.mkdir(exist_ok=True, parents=True)
+
         plt.figure()
         plot_mask(mask=mask, hop=dgt_params['hop'],
                   n_bins=dgt_params['n_bins'], fs=signal_params['fs'])
         plt.title('Masked observation')
+        plt.savefig(fig_dir / 'interp_mask.pdf')
 
         plt.figure()
         plotdgtreal(coef=x_tf, a=dgt_params['hop'],
                     M=dgt_params['n_bins'], fs=signal_params['fs'])
         plt.title('Interpolated TF matrix')
+        plt.savefig(fig_dir / 'interp_tf_est.pdf')
 
         plt.figure()
         plot_spectrogram(x=x_est, dgt_params=dgt_params,
                          fs=signal_params['fs'])
         plt.title('Reconstructed signal by interp')
+        plt.savefig(fig_dir / 'interp_sig_est.pdf')
 
     return x_est
-
-
-if __name__ == '__main__':
-    win_type = 'gauss'
-    win_dur = 256 / 8000
-    hop_ratio = 1 / 4
-    n_bins_ratio = 4
-    delta_mix_db = 0
-    delta_loc_db = 30
-    n_iter_closing = n_iter_opening = 3
-    wb_to_loc_ratio_db = 8
-    closing_first = True
-    or_mask = True
-
-
-    fig_dir = Path('fig_interpolation')
-    fig_dir.mkdir(parents=True, exist_ok=True)
-
-    x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \
-        get_mix(loc_source='bird', wideband_src='car',
-                wb_to_loc_ratio_db=wb_to_loc_ratio_db,
-                win_dur=win_dur, win_type=win_type,
-                hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio,
-                n_iter_closing=n_iter_closing, n_iter_opening=n_iter_opening,
-                closing_first=closing_first,
-                delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db,
-                or_mask=or_mask, fig_dir=fig_dir)
-
-    x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
-                                   fig_dir)
diff --git a/python/tffpy/tests/test_interpolation_solver.py b/python/tffpy/tests/test_interpolation_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..46351271dcffa88fe2dc9052f860385185531a21
--- /dev/null
+++ b/python/tffpy/tests/test_interpolation_solver.py
@@ -0,0 +1,40 @@
+import unittest
+
+import numpy as np
+
+from tffpy.datasets import get_mix
+from tffpy.interpolation_solver import solve_by_interpolation
+
+
+class TestInterpolationSolver(unittest.TestCase):
+    def test_interpolation_solver(self):
+        win_type = 'gauss'
+        win_dur = 256 / 8000
+        hop_ratio = 1 / 4
+        n_bins_ratio = 4
+        delta_mix_db = 0
+        delta_loc_db = 30
+        n_iter_closing = n_iter_opening = 3
+        wb_to_loc_ratio_db = 8
+        closing_first = True
+        or_mask = True
+
+        fig_dir = 'test_fig_interpolation'
+
+        x_mix, dgt_params, signal_params, mask, x_bird, x_engine = \
+            get_mix(loc_source='bird', wideband_src='car', crop=4096,
+                    wb_to_loc_ratio_db=wb_to_loc_ratio_db,
+                    win_dur=win_dur, win_type=win_type,
+                    hop_ratio=hop_ratio, n_bins_ratio=n_bins_ratio,
+                    n_iter_closing=n_iter_closing,
+                    n_iter_opening=n_iter_opening,
+                    closing_first=closing_first,
+                    delta_mix_db=delta_mix_db, delta_loc_db=delta_loc_db,
+                    or_mask=or_mask, fig_dir=fig_dir)
+
+        x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params,
+                                       fig_dir)
+        np.testing.assert_array_equal(x_est.shape, x_mix.shape)
+
+        x_est = solve_by_interpolation(x_mix, mask, dgt_params, signal_params)
+        np.testing.assert_array_equal(x_est.shape, x_mix.shape)