diff --git a/madarrays/waveform.py b/madarrays/waveform.py index 59613d74f3eb990c8166795197a647187d33d850..b3426abc8a46a4373b95d50ef55d3ba7e702f243 100644 --- a/madarrays/waveform.py +++ b/madarrays/waveform.py @@ -60,6 +60,14 @@ from .mad_array import MadArray VALID_IO_FS = {1, 8000, 16000, 32000, 48000, 11025, 22050, 44100, 88200} +def _check_compatibility_fs(w1, w2): + """Raise an exception if the sampling frequency of the two Waveforms + are different.""" + if w1.fs != w2.fs: + errmsg = 'Waveforms do not have the same fs: {} and {}.' + raise ValueError(errmsg.format(w1.fs, w2.fs)) + + class Waveform(MadArray): """Subclass of MadArray to handle mono and stereo audio signals. @@ -158,16 +166,14 @@ class Waveform(MadArray): def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + if len(inputs) == 2: + if (isinstance(inputs[0], Waveform) and + isinstance(inputs[1], Waveform)): + _check_compatibility_fs(*inputs) + output = super().__array_ufunc__(ufunc, method, *inputs, **kwargs) if isinstance(output, MadArray): - if len(inputs) == 2: - if (isinstance(inputs[0], Waveform) and - isinstance(inputs[1], Waveform) and - inputs[0].fs != inputs[1].fs): - errmsg = 'Waveforms do not have the same fs: {} and {}' - raise ValueError(errmsg.format(inputs[0].fs, inputs[1].fs)) - output = output.view(Waveform) output.fs = inputs[0].fs if isinstance(inputs[0], Waveform) \ else inputs[1].fs @@ -773,8 +779,9 @@ class Waveform(MadArray): return Waveform(self) def __eq__(self, other): - # Test the compatability between Waveform - _ = self + other + if isinstance(other, Waveform): + _check_compatibility_fs(self, other) + return super().__eq__(other) def is_equal(self, other):