Commit 0b09a453 authored by Ronan Hamon's avatar Ronan Hamon

Update list of ufunc that should not return a MadArray

parent 11bda000
Pipeline #485 passed with stages
in 1 minute and 1 second
......@@ -98,8 +98,15 @@ def _complex_masking_only(f):
return decorated
UFUNC_RETURNING_MADARRAYS = ['add', 'subtract', 'multiply', 'true_divide',
'floor_divide', 'floor', 'ceil', 'absolute']
UFUNC_NOT_RETURNING_MADARRAYS = ['bitwise_and', 'bitwise_or', 'bitwise_xor',
'invert', 'left_shift', 'right_shift',
'greater', 'greater_equal', 'less',
'less_equal', 'not_equal', 'equal',
'logical_and', 'logical_or', 'logical_xor',
'logical_not', 'maximum', 'minimum', 'fmax',
'fmin', 'isfinite', 'isinf', 'isnan', 'isnat',
'signbit', 'copysign', 'nextafter', 'spacing',
'modf', 'frexp', 'fmod']
class MadArray(np.ndarray):
......@@ -341,7 +348,7 @@ class MadArray(np.ndarray):
for result, output in zip(results, outputs):
if output is None:
if (method == '__call__' and
ufunc.__name__ in UFUNC_RETURNING_MADARRAYS):
ufunc.__name__ not in UFUNC_NOT_RETURNING_MADARRAYS):
new_results.append(np.asarray(result).view(MadArray))
new_results[-1]._mask = mask
new_results[-1]._complex_masking = complex_masking
......
......@@ -523,6 +523,39 @@ class TestMadArray:
assert not isinstance(m, MadArray)
np.testing.assert_equal(m, np.std(x))
m = np.abs(ma)
assert isinstance(m, MadArray)
np.testing.assert_equal(m, np.abs(x))
m = np.sqrt(ma)
assert isinstance(m, MadArray)
np.testing.assert_equal(m, np.sqrt(x))
m = ma**2
assert isinstance(m, MadArray)
np.testing.assert_equal(m, x**2)
m = np.conj(ma)
assert isinstance(m, MadArray)
np.testing.assert_equal(m, np.conj(x))
if np.issubdtype(ma.dtype, np.floating):
m = np.floor(ma)
assert isinstance(m, MadArray)
np.testing.assert_equal(m, np.floor(x))
m = np.ceil(ma)
assert isinstance(m, MadArray)
np.testing.assert_equal(m, np.ceil(x))
a = ma < np.mean(ma)
assert not isinstance(a, MadArray)
np.testing.assert_equal(a, x < np.mean(ma))
a = ma >= np.mean(ma)
assert not isinstance(a, MadArray)
np.testing.assert_equal(a, x >= np.mean(ma))
def test_eq_ne_numpy(self):
ma = MadArray(self.x_float)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment