diff --git a/ltfatpy/gabor/tfplot.py b/ltfatpy/gabor/tfplot.py
index 81a75882d7d9b01f8b99ef615ac64fd872e46f5f..04093f2698bd1d29d10d69ebbb241deb7c955d22 100644
--- a/ltfatpy/gabor/tfplot.py
+++ b/ltfatpy/gabor/tfplot.py
@@ -245,7 +245,8 @@ def tfplot(coef, step, yr, fs=None, dynrange=None, normalization='db',
             # plot_surface
             from mpl_toolkits.mplot3d import Axes3D
 
-            ax = plt.gca(projection='3d')
+            plt.delaxes()
+            ax = plt.gcf().add_subplot(111, projection='3d')
             ax.azim = -130.
             ax.elev = 30.
             xgrid, ygrid = np.meshgrid(xr, yr)