Skip to content
Snippets Groups Projects
Commit 2c1ee178 authored by Ronan Hamon's avatar Ronan Hamon
Browse files

fix handling of ufunc in MadArray and Waveform (issue #1)

parent 7a918170
Branches
Tags
No related merge requests found
Pipeline #
......@@ -292,6 +292,80 @@ class MadArray(np.ndarray):
def __array_wrap__(self, obj, context=None):
return obj[()] if obj.shape == () else obj
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
args = []
is_mad = []
for input_ in inputs:
if isinstance(input_, MadArray):
args.append(input_.view(np.ndarray))
is_mad.append(True)
else:
args.append(input_)
is_mad.append(False)
if len(is_mad) > 1:
if np.all(is_mad):
if inputs[0]._complex_masking or inputs[1]._complex_masking:
errmsg = 'Operation not permitted when complex masking.'
raise ValueError(errmsg)
mask = MadArray(inputs[0].data, **
_merge_masks(inputs[0], inputs[1]))._mask
complex_masking = inputs[0]._complex_masking
masked_indexing = inputs[0]._masked_indexing
else:
where_is_mad = np.argwhere(is_mad).squeeze()
mask = inputs[where_is_mad]._mask
complex_masking = inputs[where_is_mad]._complex_masking
masked_indexing = inputs[where_is_mad]._masked_indexing
else:
mask = inputs[0]._mask
complex_masking = inputs[0]._complex_masking
masked_indexing = inputs[0]._masked_indexing
outputs = kwargs.pop('out', None)
if outputs:
out_args = []
for output in outputs:
if isinstance(output, MadArray):
out_args.append(output.view(np.ndarray))
else:
out_args.append(output)
kwargs['out'] = tuple(out_args)
else:
outputs = (None,) * ufunc.nout
results = super().__array_ufunc__(ufunc, method, *args, **kwargs)
if results is NotImplemented:
return NotImplemented
if method == 'at':
return
if ufunc.nout == 1:
results = (results,)
new_results = []
for result, output in zip(results, outputs):
if output is None:
if not ufunc.__name__.startswith('is'):
new_results.append(np.asarray(result).view(MadArray))
new_results[-1]._mask = mask
new_results[-1]._complex_masking = complex_masking
new_results[-1]._masked_indexing = masked_indexing
else:
new_results.append(np.asarray(result).view(np.ndarray))
else:
new_results.append(output)
results = tuple(new_results)
return results[0] if len(results) == 1 else results
# return MadArray(results, **mask_arg,
# masked_indexing=inputs[0]._masked_indexing)
def __getitem__(self, index):
if (getattr(self, '_masked_indexing', None) is not None and
......@@ -582,61 +656,6 @@ class MadArray(np.ndarray):
data[self.unknown_mask] = fill_value
return data
def __add__(self, other):
if isinstance(other, MadArray):
if self._complex_masking or other._complex_masking:
errmsg = 'Operation not permitted when complex masking.'
raise ValueError(errmsg)
return MadArray(np.add(self.to_np_array(),
other.to_np_array()),
**_merge_masks(self, other),
masked_indexing=self._masked_indexing)
else:
return super().__add__(other)
def __sub__(self, other):
return self.__add__(-other)
def __mul__(self, other):
if isinstance(other, MadArray):
if self._complex_masking or other._complex_masking:
errmsg = 'Operation not permitted when complex masking.'
raise ValueError(errmsg)
return MadArray(np.multiply(self.to_np_array(),
other.to_np_array()),
**_merge_masks(self, other),
masked_indexing=self._masked_indexing)
else:
return super().__mul__(other)
def __truediv__(self, other):
if isinstance(other, MadArray):
if self._complex_masking or other._complex_masking:
errmsg = 'Operation not permitted when complex masking.'
raise ValueError(errmsg)
return MadArray(np.true_divide(self.to_np_array(),
other.to_np_array()),
**_merge_masks(self, other),
masked_indexing=self._masked_indexing)
else:
return super().__truediv__(other)
def __floordiv__(self, other):
if isinstance(other, MadArray):
if self._complex_masking or other._complex_masking:
errmsg = 'Operation not permitted when complex masking.'
raise ValueError(errmsg)
return MadArray(np.floor_divide(self.to_np_array(),
other.to_np_array()),
**_merge_masks(self, other),
masked_indexing=self._masked_indexing)
else:
return super().__floordiv__(other)
def __eq__(self, other):
if isinstance(other, MadArray):
return np.logical_and(self.to_np_array(0) == other.to_np_array(0),
......
......@@ -455,6 +455,7 @@ class TestMadArray:
match = 'Operation not permitted when complex masking.'
for operator in ['+', '-', '*', '/', '//']:
print('Operation: {}'.format(operator))
for x in [self.x_float, self.x_int, self.x_complex]:
ma = MadArray(x, self.m)
......
......@@ -970,17 +970,20 @@ class TestWaveform:
def test_operations(self):
w1 = Waveform(self.x_mono, fs=self.fs)
w2 = Waveform(self.x_mono, fs=self.fs)
w3 = Waveform(self.x_mono, fs=1 if self.fs > 1 else 44100)
ma = MadArray(self.x_mono)
x = self.x_mono
for operator in ['+', '-', '*', '/', '//']:
w1 = Waveform(x, fs=self.fs)
w2 = Waveform(x, fs=self.fs)
w3 = Waveform(x, fs=1 if self.fs > 1 else 44100)
ma = MadArray(x)
for operator in ['-', '-', '*', '/', '//']:
print('Operation: {}'.format(operator))
xs = eval('x {} x'.format(operator))
ws = eval('w1 {} w2'.format(operator))
assert isinstance(ws, Waveform)
np.testing.assert_equal(ws, self.x_mono + self.x_mono)
np.testing.assert_equal(ws, xs)
match='Waveforms do not have the same fs: \d+ and \d+'
with pytest.raises(ValueError, match=match):
......@@ -988,13 +991,13 @@ class TestWaveform:
ms = eval('w1 {} ma'.format(operator))
assert isinstance(ws, Waveform)
np.testing.assert_equal(ws, self.x_mono + self.x_mono)
np.testing.assert_equal(ws, xs)
with pytest.raises(ValueError, match=match):
w1 == w2
w1 == w3
assert not w1.is_equal(w2)
assert not w1.is_equal(w3)
def test_fade(self):
all_modes = {'both', 'in', 'out'}
......
......@@ -156,6 +156,23 @@ class Waveform(MadArray):
return obj
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
output = super().__array_ufunc__(ufunc, method, *inputs, **kwargs)
if isinstance(output, MadArray):
if len(inputs) == 2:
if (isinstance(inputs[0], Waveform) and
isinstance(inputs[1], Waveform) and
inputs[0].fs != inputs[1].fs):
errmsg = 'Waveforms do not have the same fs: {} and {}'
raise ValueError(errmsg.format(inputs[0].fs, inputs[1].fs))
output = output.view(Waveform)
output.fs = inputs[0].fs if isinstance(inputs[0], Waveform) else inputs[1].fs
return output
def __array_finalize__(self, obj):
super().__array_finalize__(obj)
self._fs = getattr(obj, '_fs', 1)
......@@ -754,37 +771,16 @@ class Waveform(MadArray):
def copy(self):
return Waveform(self)
def __add__(self, other):
# def __add__(self, other):
if isinstance(other, Waveform) and self.fs != other.fs:
errmsg = 'Waveforms do not have the same fs: {} and {}'
raise ValueError(errmsg.format(self.fs, other.fs))
return super().__add__(other)
return Waveform(super().__add__(other), fs=self.fs)
def __sub__(self, ma):
return self.__add__(-ma)
def __mul__(self, other):
if isinstance(other, Waveform) and self.fs != other.fs:
errmsg = 'Waveforms do not have the same fs: {} and {}'
raise ValueError(errmsg.format(self.fs, other.fs))
return super().__mul__(other)
def __truediv__(self, other):
if isinstance(other, Waveform) and self.fs != other.fs:
errmsg = 'Waveforms do not have the same fs: {} and {}'
raise ValueError(errmsg.format(self.fs, other.fs))
return super().__truediv__(other)
def __floordiv__(self, other):
if isinstance(other, Waveform) and self.fs != other.fs:
errmsg = 'Waveforms do not have the same fs: {} and {}'
raise ValueError(errmsg.format(self.fs, other.fs))
return super().__floordiv__(other)
def __eq__(self, other):
if isinstance(other, Waveform) and self.fs != other.fs:
errmsg = 'Waveforms do not have the same fs: {} and {}'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment