From 8d588f6cc70eba85326017ee8bad0a5a10a223d4 Mon Sep 17 00:00:00 2001 From: Ronan Hamon <ronan.hamon@lis-lab.fr> Date: Thu, 26 Apr 2018 16:03:31 +0200 Subject: [PATCH] fix type output when using mean on madarrays --- madarrays/mad_array.py | 3 ++- madarrays/tests/test_madarray.py | 21 +++++++++++++++++++-- madarrays/tests/test_waveform.py | 20 +++++++++++++++++++- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/madarrays/mad_array.py b/madarrays/mad_array.py index 98c02e1..208eb27 100644 --- a/madarrays/mad_array.py +++ b/madarrays/mad_array.py @@ -340,7 +340,8 @@ class MadArray(np.ndarray): new_results = [] for result, output in zip(results, outputs): if output is None: - if ufunc.__name__ in UFUNC_RETURNING_MADARRAYS: + if (method == '__call__' and + ufunc.__name__ in UFUNC_RETURNING_MADARRAYS): new_results.append(np.asarray(result).view(MadArray)) new_results[-1]._mask = mask new_results[-1]._complex_masking = complex_masking diff --git a/madarrays/tests/test_madarray.py b/madarrays/tests/test_madarray.py index 75cf063..ff08855 100644 --- a/madarrays/tests/test_madarray.py +++ b/madarrays/tests/test_madarray.py @@ -450,8 +450,7 @@ class TestMadArray: assert id(ma) != id(ma_copy) assert id(ma._mask) != id(ma_copy._mask) - def test_operations(self): - + def test_basic_operations(self): match = 'Operation not permitted when complex masking.' for operator in ['+', '-', '*', '/', '//']: print('Operation: {}'.format(operator)) @@ -506,6 +505,24 @@ class TestMadArray: with pytest.raises(ValueError, match=match): ma + ma2 + def test_advanced_operations(self): + + for x in [self.x_float, self.x_int, self.x_complex]: + + ma = MadArray(x, self.m) + + m = np.mean(ma) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.mean(x)) + + m = np.mean(ma, axis=0) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.mean(x, axis=0)) + + m = np.std(ma) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.std(x)) + def test_eq_ne_numpy(self): ma = MadArray(self.x_float) diff --git a/madarrays/tests/test_waveform.py b/madarrays/tests/test_waveform.py index 1e565d5..2e0fb83 100644 --- a/madarrays/tests/test_waveform.py +++ b/madarrays/tests/test_waveform.py @@ -970,7 +970,7 @@ class TestWaveform: assert cmp_w.dtype == np.bool assert np.all(~cmp_w) - def test_operations(self): + def test_basic_operations(self): x = self.x_mono @@ -995,6 +995,24 @@ class TestWaveform: assert isinstance(ws, Waveform) np.testing.assert_equal(ws, xs) + def test_advanced_operations(self): + + x = self.x_mono + + ma = MadArray(x, self.m_mono) + + m = np.mean(ma) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.mean(x)) + + m = np.mean(ma, axis=0) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.mean(x, axis=0)) + + m = np.std(ma) + assert not isinstance(m, MadArray) + np.testing.assert_equal(m, np.std(x)) + def test_eq(self): x = self.x_mono -- GitLab