diff --git a/madarrays/mad_array.py b/madarrays/mad_array.py index 0993007081b5a93d2de37615751dda2d04ef1ad3..bd0c73c1a0ced8f0e374d8e4257e6f034f9b6b66 100644 --- a/madarrays/mad_array.py +++ b/madarrays/mad_array.py @@ -292,6 +292,80 @@ class MadArray(np.ndarray): def __array_wrap__(self, obj, context=None): return obj[()] if obj.shape == () else obj + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + + args = [] + is_mad = [] + for input_ in inputs: + if isinstance(input_, MadArray): + args.append(input_.view(np.ndarray)) + is_mad.append(True) + else: + args.append(input_) + is_mad.append(False) + + if len(is_mad) > 1: + if np.all(is_mad): + if inputs[0]._complex_masking or inputs[1]._complex_masking: + errmsg = 'Operation not permitted when complex masking.' + raise ValueError(errmsg) + + mask = MadArray(inputs[0].data, ** + _merge_masks(inputs[0], inputs[1]))._mask + complex_masking = inputs[0]._complex_masking + masked_indexing = inputs[0]._masked_indexing + else: + where_is_mad = np.argwhere(is_mad).squeeze() + mask = inputs[where_is_mad]._mask + complex_masking = inputs[where_is_mad]._complex_masking + masked_indexing = inputs[where_is_mad]._masked_indexing + + else: + mask = inputs[0]._mask + complex_masking = inputs[0]._complex_masking + masked_indexing = inputs[0]._masked_indexing + + outputs = kwargs.pop('out', None) + if outputs: + out_args = [] + for output in outputs: + if isinstance(output, MadArray): + out_args.append(output.view(np.ndarray)) + else: + out_args.append(output) + kwargs['out'] = tuple(out_args) + else: + outputs = (None,) * ufunc.nout + + results = super().__array_ufunc__(ufunc, method, *args, **kwargs) + + if results is NotImplemented: + return NotImplemented + + if method == 'at': + return + + if ufunc.nout == 1: + results = (results,) + + new_results = [] + for result, output in zip(results, outputs): + if output is None: + if not ufunc.__name__.startswith('is'): + new_results.append(np.asarray(result).view(MadArray)) + new_results[-1]._mask = mask + new_results[-1]._complex_masking = complex_masking + new_results[-1]._masked_indexing = masked_indexing + else: + new_results.append(np.asarray(result).view(np.ndarray)) + else: + new_results.append(output) + results = tuple(new_results) + + return results[0] if len(results) == 1 else results + + # return MadArray(results, **mask_arg, + # masked_indexing=inputs[0]._masked_indexing) def __getitem__(self, index): if (getattr(self, '_masked_indexing', None) is not None and @@ -582,61 +656,6 @@ class MadArray(np.ndarray): data[self.unknown_mask] = fill_value return data - def __add__(self, other): - if isinstance(other, MadArray): - if self._complex_masking or other._complex_masking: - errmsg = 'Operation not permitted when complex masking.' - raise ValueError(errmsg) - - return MadArray(np.add(self.to_np_array(), - other.to_np_array()), - **_merge_masks(self, other), - masked_indexing=self._masked_indexing) - else: - return super().__add__(other) - - def __sub__(self, other): - return self.__add__(-other) - - def __mul__(self, other): - if isinstance(other, MadArray): - if self._complex_masking or other._complex_masking: - errmsg = 'Operation not permitted when complex masking.' - raise ValueError(errmsg) - - return MadArray(np.multiply(self.to_np_array(), - other.to_np_array()), - **_merge_masks(self, other), - masked_indexing=self._masked_indexing) - else: - return super().__mul__(other) - - def __truediv__(self, other): - if isinstance(other, MadArray): - if self._complex_masking or other._complex_masking: - errmsg = 'Operation not permitted when complex masking.' - raise ValueError(errmsg) - - return MadArray(np.true_divide(self.to_np_array(), - other.to_np_array()), - **_merge_masks(self, other), - masked_indexing=self._masked_indexing) - else: - return super().__truediv__(other) - - def __floordiv__(self, other): - if isinstance(other, MadArray): - if self._complex_masking or other._complex_masking: - errmsg = 'Operation not permitted when complex masking.' - raise ValueError(errmsg) - - return MadArray(np.floor_divide(self.to_np_array(), - other.to_np_array()), - **_merge_masks(self, other), - masked_indexing=self._masked_indexing) - else: - return super().__floordiv__(other) - def __eq__(self, other): if isinstance(other, MadArray): return np.logical_and(self.to_np_array(0) == other.to_np_array(0), diff --git a/madarrays/tests/test_madarray.py b/madarrays/tests/test_madarray.py index 490764352dbc2c74180e6b780df4a876feec3ba3..75cf063a0ae45bae3789f38f3c8f5487a4af7ce9 100644 --- a/madarrays/tests/test_madarray.py +++ b/madarrays/tests/test_madarray.py @@ -455,6 +455,7 @@ class TestMadArray: match = 'Operation not permitted when complex masking.' for operator in ['+', '-', '*', '/', '//']: print('Operation: {}'.format(operator)) + for x in [self.x_float, self.x_int, self.x_complex]: ma = MadArray(x, self.m) diff --git a/madarrays/tests/test_waveform.py b/madarrays/tests/test_waveform.py index 24d7bb974616f40d66e20e4f1738b08015ac1f01..83302d55416643fb5af1b1f6328ced7721f502dc 100644 --- a/madarrays/tests/test_waveform.py +++ b/madarrays/tests/test_waveform.py @@ -970,17 +970,20 @@ class TestWaveform: def test_operations(self): - w1 = Waveform(self.x_mono, fs=self.fs) - w2 = Waveform(self.x_mono, fs=self.fs) - w3 = Waveform(self.x_mono, fs=1 if self.fs > 1 else 44100) - ma = MadArray(self.x_mono) + x = self.x_mono - for operator in ['+', '-', '*', '/', '//']: + w1 = Waveform(x, fs=self.fs) + w2 = Waveform(x, fs=self.fs) + w3 = Waveform(x, fs=1 if self.fs > 1 else 44100) + ma = MadArray(x) + + for operator in ['-', '-', '*', '/', '//']: print('Operation: {}'.format(operator)) + xs = eval('x {} x'.format(operator)) ws = eval('w1 {} w2'.format(operator)) assert isinstance(ws, Waveform) - np.testing.assert_equal(ws, self.x_mono + self.x_mono) + np.testing.assert_equal(ws, xs) match='Waveforms do not have the same fs: \d+ and \d+' with pytest.raises(ValueError, match=match): @@ -988,13 +991,13 @@ class TestWaveform: ms = eval('w1 {} ma'.format(operator)) assert isinstance(ws, Waveform) - np.testing.assert_equal(ws, self.x_mono + self.x_mono) + np.testing.assert_equal(ws, xs) with pytest.raises(ValueError, match=match): - w1 == w2 + w1 == w3 - assert not w1.is_equal(w2) + assert not w1.is_equal(w3) def test_fade(self): all_modes = {'both', 'in', 'out'} diff --git a/madarrays/waveform.py b/madarrays/waveform.py index c323293bfc65ec61e90cb9fadb6d4019187b77df..08ad2ef04d308d9433c0584339917f515d71902f 100644 --- a/madarrays/waveform.py +++ b/madarrays/waveform.py @@ -156,6 +156,23 @@ class Waveform(MadArray): return obj + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + + output = super().__array_ufunc__(ufunc, method, *inputs, **kwargs) + + if isinstance(output, MadArray): + if len(inputs) == 2: + if (isinstance(inputs[0], Waveform) and + isinstance(inputs[1], Waveform) and + inputs[0].fs != inputs[1].fs): + errmsg = 'Waveforms do not have the same fs: {} and {}' + raise ValueError(errmsg.format(inputs[0].fs, inputs[1].fs)) + + output = output.view(Waveform) + output.fs = inputs[0].fs if isinstance(inputs[0], Waveform) else inputs[1].fs + + return output + def __array_finalize__(self, obj): super().__array_finalize__(obj) self._fs = getattr(obj, '_fs', 1) @@ -754,37 +771,16 @@ class Waveform(MadArray): def copy(self): return Waveform(self) - def __add__(self, other): + # def __add__(self, other): if isinstance(other, Waveform) and self.fs != other.fs: errmsg = 'Waveforms do not have the same fs: {} and {}' raise ValueError(errmsg.format(self.fs, other.fs)) - return super().__add__(other) + return Waveform(super().__add__(other), fs=self.fs) def __sub__(self, ma): return self.__add__(-ma) - def __mul__(self, other): - if isinstance(other, Waveform) and self.fs != other.fs: - errmsg = 'Waveforms do not have the same fs: {} and {}' - raise ValueError(errmsg.format(self.fs, other.fs)) - - return super().__mul__(other) - - def __truediv__(self, other): - if isinstance(other, Waveform) and self.fs != other.fs: - errmsg = 'Waveforms do not have the same fs: {} and {}' - raise ValueError(errmsg.format(self.fs, other.fs)) - - return super().__truediv__(other) - - def __floordiv__(self, other): - if isinstance(other, Waveform) and self.fs != other.fs: - errmsg = 'Waveforms do not have the same fs: {} and {}' - raise ValueError(errmsg.format(self.fs, other.fs)) - - return super().__floordiv__(other) - def __eq__(self, other): if isinstance(other, Waveform) and self.fs != other.fs: errmsg = 'Waveforms do not have the same fs: {} and {}'