Commit 877c8698 authored by Florent Jaillet's avatar Florent Jaillet

Merge branch 'master' of gitlab.lis-lab.fr:skmad-suite/madarrays

parents 49e0c430 8d588f6c
Pipeline #476 passed with stages
in 46 seconds
...@@ -340,7 +340,8 @@ class MadArray(np.ndarray): ...@@ -340,7 +340,8 @@ class MadArray(np.ndarray):
new_results = [] new_results = []
for result, output in zip(results, outputs): for result, output in zip(results, outputs):
if output is None: if output is None:
if ufunc.__name__ in UFUNC_RETURNING_MADARRAYS: if (method == '__call__' and
ufunc.__name__ in UFUNC_RETURNING_MADARRAYS):
new_results.append(np.asarray(result).view(MadArray)) new_results.append(np.asarray(result).view(MadArray))
new_results[-1]._mask = mask new_results[-1]._mask = mask
new_results[-1]._complex_masking = complex_masking new_results[-1]._complex_masking = complex_masking
......
...@@ -450,8 +450,7 @@ class TestMadArray: ...@@ -450,8 +450,7 @@ class TestMadArray:
assert id(ma) != id(ma_copy) assert id(ma) != id(ma_copy)
assert id(ma._mask) != id(ma_copy._mask) assert id(ma._mask) != id(ma_copy._mask)
def test_operations(self): def test_basic_operations(self):
match = 'Operation not permitted when complex masking.' match = 'Operation not permitted when complex masking.'
for operator in ['+', '-', '*', '/', '//']: for operator in ['+', '-', '*', '/', '//']:
print('Operation: {}'.format(operator)) print('Operation: {}'.format(operator))
...@@ -506,6 +505,24 @@ class TestMadArray: ...@@ -506,6 +505,24 @@ class TestMadArray:
with pytest.raises(ValueError, match=match): with pytest.raises(ValueError, match=match):
ma + ma2 ma + ma2
def test_advanced_operations(self):
for x in [self.x_float, self.x_int, self.x_complex]:
ma = MadArray(x, self.m)
m = np.mean(ma)
assert not isinstance(m, MadArray)
np.testing.assert_equal(m, np.mean(x))
m = np.mean(ma, axis=0)
assert not isinstance(m, MadArray)
np.testing.assert_equal(m, np.mean(x, axis=0))
m = np.std(ma)
assert not isinstance(m, MadArray)
np.testing.assert_equal(m, np.std(x))
def test_eq_ne_numpy(self): def test_eq_ne_numpy(self):
ma = MadArray(self.x_float) ma = MadArray(self.x_float)
......
...@@ -970,7 +970,7 @@ class TestWaveform: ...@@ -970,7 +970,7 @@ class TestWaveform:
assert cmp_w.dtype == np.bool assert cmp_w.dtype == np.bool
assert np.all(~cmp_w) assert np.all(~cmp_w)
def test_operations(self): def test_basic_operations(self):
x = self.x_mono x = self.x_mono
...@@ -995,6 +995,24 @@ class TestWaveform: ...@@ -995,6 +995,24 @@ class TestWaveform:
assert isinstance(ws, Waveform) assert isinstance(ws, Waveform)
np.testing.assert_equal(ws, xs) np.testing.assert_equal(ws, xs)
def test_advanced_operations(self):
x = self.x_mono
ma = MadArray(x, self.m_mono)
m = np.mean(ma)
assert not isinstance(m, MadArray)
np.testing.assert_equal(m, np.mean(x))
m = np.mean(ma, axis=0)
assert not isinstance(m, MadArray)
np.testing.assert_equal(m, np.mean(x, axis=0))
m = np.std(ma)
assert not isinstance(m, MadArray)
np.testing.assert_equal(m, np.std(x))
def test_eq(self): def test_eq(self):
x = self.x_mono x = self.x_mono
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment