From 8d588f6cc70eba85326017ee8bad0a5a10a223d4 Mon Sep 17 00:00:00 2001
From: Ronan Hamon <ronan.hamon@lis-lab.fr>
Date: Thu, 26 Apr 2018 16:03:31 +0200
Subject: [PATCH] fix type output when using mean on madarrays

---
 madarrays/mad_array.py           |  3 ++-
 madarrays/tests/test_madarray.py | 21 +++++++++++++++++++--
 madarrays/tests/test_waveform.py | 20 +++++++++++++++++++-
 3 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/madarrays/mad_array.py b/madarrays/mad_array.py
index 98c02e1..208eb27 100644
--- a/madarrays/mad_array.py
+++ b/madarrays/mad_array.py
@@ -340,7 +340,8 @@ class MadArray(np.ndarray):
         new_results = []
         for result, output in zip(results, outputs):
             if output is None:
-                if ufunc.__name__ in UFUNC_RETURNING_MADARRAYS:
+                if (method == '__call__' and
+                        ufunc.__name__ in UFUNC_RETURNING_MADARRAYS):
                     new_results.append(np.asarray(result).view(MadArray))
                     new_results[-1]._mask = mask
                     new_results[-1]._complex_masking = complex_masking
diff --git a/madarrays/tests/test_madarray.py b/madarrays/tests/test_madarray.py
index 75cf063..ff08855 100644
--- a/madarrays/tests/test_madarray.py
+++ b/madarrays/tests/test_madarray.py
@@ -450,8 +450,7 @@ class TestMadArray:
         assert id(ma) != id(ma_copy)
         assert id(ma._mask) != id(ma_copy._mask)
 
-    def test_operations(self):
-
+    def test_basic_operations(self):
         match = 'Operation not permitted when complex masking.'
         for operator in ['+', '-', '*', '/', '//']:
             print('Operation: {}'.format(operator))
@@ -506,6 +505,24 @@ class TestMadArray:
                     with pytest.raises(ValueError, match=match):
                         ma + ma2
 
+    def test_advanced_operations(self):
+
+        for x in [self.x_float, self.x_int, self.x_complex]:
+
+            ma = MadArray(x, self.m)
+
+            m = np.mean(ma)
+            assert not isinstance(m, MadArray)
+            np.testing.assert_equal(m, np.mean(x))
+
+            m = np.mean(ma, axis=0)
+            assert not isinstance(m, MadArray)
+            np.testing.assert_equal(m, np.mean(x, axis=0))
+
+            m = np.std(ma)
+            assert not isinstance(m, MadArray)
+            np.testing.assert_equal(m, np.std(x))
+
     def test_eq_ne_numpy(self):
 
         ma = MadArray(self.x_float)
diff --git a/madarrays/tests/test_waveform.py b/madarrays/tests/test_waveform.py
index 1e565d5..2e0fb83 100644
--- a/madarrays/tests/test_waveform.py
+++ b/madarrays/tests/test_waveform.py
@@ -970,7 +970,7 @@ class TestWaveform:
         assert cmp_w.dtype == np.bool
         assert np.all(~cmp_w)
 
-    def test_operations(self):
+    def test_basic_operations(self):
 
         x = self.x_mono
 
@@ -995,6 +995,24 @@ class TestWaveform:
             assert isinstance(ws, Waveform)
             np.testing.assert_equal(ws, xs)
 
+    def test_advanced_operations(self):
+
+        x = self.x_mono
+
+        ma = MadArray(x, self.m_mono)
+
+        m = np.mean(ma)
+        assert not isinstance(m, MadArray)
+        np.testing.assert_equal(m, np.mean(x))
+
+        m = np.mean(ma, axis=0)
+        assert not isinstance(m, MadArray)
+        np.testing.assert_equal(m, np.mean(x, axis=0))
+
+        m = np.std(ma)
+        assert not isinstance(m, MadArray)
+        np.testing.assert_equal(m, np.std(x))
+
     def test_eq(self):
 
         x = self.x_mono
-- 
GitLab