diff --git a/madarrays/mad_array.py b/madarrays/mad_array.py index 98c02e1ccbfe35bfe6bf8adf93b3b7d0f8b65818..208eb27465a94057e316684404182b016d032223 100644 --- a/madarrays/mad_array.py +++ b/madarrays/mad_array.py @@ -340,7 +340,8 @@ class MadArray(np.ndarray): new_results = [] for result, output in zip(results, outputs): if output is None: - if ufunc.__name__ in UFUNC_RETURNING_MADARRAYS: + if (method == '__call__' and + ufunc.__name__ in UFUNC_RETURNING_MADARRAYS): new_results.append(np.asarray(result).view(MadArray)) new_results[-1]._mask = mask new_results[-1]._complex_masking = complex_masking diff --git a/madarrays/tests/test_madarray.py b/madarrays/tests/test_madarray.py index 75cf063a0ae45bae3789f38f3c8f5487a4af7ce9..ff08855a878676221e9d0f8d7aec7b8b9f82e9f0 100644 --- a/madarrays/tests/test_madarray.py +++ b/madarrays/tests/test_madarray.py @@ -450,8 +450,7 @@ class TestMadArray: assert id(ma) != id(ma_copy) assert id(ma._mask) != id(ma_copy._mask) - def test_operations(self): - + def test_basic_operations(self): match = 'Operation not permitted when complex masking.' for operator in ['+', '-', '*', '/', '//']: print('Operation: {}'.format(operator)) @@ -506,6 +505,24 @@ class TestMadArray: with pytest.raises(ValueError, match=match): ma + ma2 + def test_advanced_operations(self): + + for x in [self.x_float, self.x_int, self.x_complex]: + + ma = MadArray(x, self.m) + + m = np.mean(ma) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.mean(x)) + + m = np.mean(ma, axis=0) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.mean(x, axis=0)) + + m = np.std(ma) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.std(x)) + def test_eq_ne_numpy(self): ma = MadArray(self.x_float) diff --git a/madarrays/tests/test_waveform.py b/madarrays/tests/test_waveform.py index 1e565d559d8c714175deaf34d7cb720636111da4..2e0fb8387f8b9ee3bd8b359dbafa7dde711de373 100644 --- a/madarrays/tests/test_waveform.py +++ b/madarrays/tests/test_waveform.py @@ -970,7 +970,7 @@ class TestWaveform: assert cmp_w.dtype == np.bool assert np.all(~cmp_w) - def test_operations(self): + def test_basic_operations(self): x = self.x_mono @@ -995,6 +995,24 @@ class TestWaveform: assert isinstance(ws, Waveform) np.testing.assert_equal(ws, xs) + def test_advanced_operations(self): + + x = self.x_mono + + ma = MadArray(x, self.m_mono) + + m = np.mean(ma) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.mean(x)) + + m = np.mean(ma, axis=0) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.mean(x, axis=0)) + + m = np.std(ma) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.std(x)) + def test_eq(self): x = self.x_mono