Skip to content
Snippets Groups Projects
Commit 0b09a453 authored by Ronan Hamon's avatar Ronan Hamon
Browse files

Update list of ufunc that should not return a MadArray

parent 11bda000
No related branches found
No related tags found
No related merge requests found
Pipeline #
...@@ -98,8 +98,15 @@ def _complex_masking_only(f): ...@@ -98,8 +98,15 @@ def _complex_masking_only(f):
return decorated return decorated
UFUNC_RETURNING_MADARRAYS = ['add', 'subtract', 'multiply', 'true_divide', UFUNC_NOT_RETURNING_MADARRAYS = ['bitwise_and', 'bitwise_or', 'bitwise_xor',
'floor_divide', 'floor', 'ceil', 'absolute'] 'invert', 'left_shift', 'right_shift',
'greater', 'greater_equal', 'less',
'less_equal', 'not_equal', 'equal',
'logical_and', 'logical_or', 'logical_xor',
'logical_not', 'maximum', 'minimum', 'fmax',
'fmin', 'isfinite', 'isinf', 'isnan', 'isnat',
'signbit', 'copysign', 'nextafter', 'spacing',
'modf', 'frexp', 'fmod']
class MadArray(np.ndarray): class MadArray(np.ndarray):
...@@ -341,7 +348,7 @@ class MadArray(np.ndarray): ...@@ -341,7 +348,7 @@ class MadArray(np.ndarray):
for result, output in zip(results, outputs): for result, output in zip(results, outputs):
if output is None: if output is None:
if (method == '__call__' and if (method == '__call__' and
ufunc.__name__ in UFUNC_RETURNING_MADARRAYS): ufunc.__name__ not in UFUNC_NOT_RETURNING_MADARRAYS):
new_results.append(np.asarray(result).view(MadArray)) new_results.append(np.asarray(result).view(MadArray))
new_results[-1]._mask = mask new_results[-1]._mask = mask
new_results[-1]._complex_masking = complex_masking new_results[-1]._complex_masking = complex_masking
......
...@@ -523,6 +523,39 @@ class TestMadArray: ...@@ -523,6 +523,39 @@ class TestMadArray:
assert not isinstance(m, MadArray) assert not isinstance(m, MadArray)
np.testing.assert_equal(m, np.std(x)) np.testing.assert_equal(m, np.std(x))
m = np.abs(ma)
assert isinstance(m, MadArray)
np.testing.assert_equal(m, np.abs(x))
m = np.sqrt(ma)
assert isinstance(m, MadArray)
np.testing.assert_equal(m, np.sqrt(x))
m = ma**2
assert isinstance(m, MadArray)
np.testing.assert_equal(m, x**2)
m = np.conj(ma)
assert isinstance(m, MadArray)
np.testing.assert_equal(m, np.conj(x))
if np.issubdtype(ma.dtype, np.floating):
m = np.floor(ma)
assert isinstance(m, MadArray)
np.testing.assert_equal(m, np.floor(x))
m = np.ceil(ma)
assert isinstance(m, MadArray)
np.testing.assert_equal(m, np.ceil(x))
a = ma < np.mean(ma)
assert not isinstance(a, MadArray)
np.testing.assert_equal(a, x < np.mean(ma))
a = ma >= np.mean(ma)
assert not isinstance(a, MadArray)
np.testing.assert_equal(a, x >= np.mean(ma))
def test_eq_ne_numpy(self): def test_eq_ne_numpy(self):
ma = MadArray(self.x_float) ma = MadArray(self.x_float)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment