Skip to content
Snippets Groups Projects
Commit 8d8c7866 authored by Florent Jaillet's avatar Florent Jaillet
Browse files

Change behavior of MadArray constructor when passing a MadArray and a mask

parent 8e57c9b7
Branches
Tags
No related merge requests found
Pipeline #
...@@ -248,11 +248,9 @@ class MadArray(np.ndarray): ...@@ -248,11 +248,9 @@ class MadArray(np.ndarray):
_data = _data.astype(np.complex) _data = _data.astype(np.complex)
if mask_magnitude is None: if mask_magnitude is None:
if isinstance(data, MadArray): if isinstance(data, MadArray) and mask_phase is None:
if data._complex_masking: mask_magnitude = data.unknown_magnitude_mask
mask_magnitude = data.unknown_magnitude_mask mask_phase = data.unknown_phase_mask
else:
mask_magnitude = data.unknown_mask
else: else:
mask_magnitude = np.zeros_like(data, dtype=np.bool) mask_magnitude = np.zeros_like(data, dtype=np.bool)
else: else:
...@@ -264,13 +262,7 @@ class MadArray(np.ndarray): ...@@ -264,13 +262,7 @@ class MadArray(np.ndarray):
_data.shape)) _data.shape))
if mask_phase is None: if mask_phase is None:
if isinstance(data, MadArray): mask_phase = np.zeros_like(data, dtype=np.bool)
if data._complex_masking:
mask_phase = data.unknown_phase_mask
else:
mask_phase = data.unknown_mask
else:
mask_phase = np.zeros_like(data, dtype=np.bool)
else: else:
mask_phase = np.array(mask_phase, dtype=np.bool) mask_phase = np.array(mask_phase, dtype=np.bool)
......
...@@ -313,8 +313,7 @@ class TestMadArray: ...@@ -313,8 +313,7 @@ class TestMadArray:
assert ma._complex_masking assert ma._complex_masking
old_ma = MadArray(x, mask_phase=self.mp, mask_magnitude=self.mm) old_ma = MadArray(x, mask_phase=self.mp, mask_magnitude=self.mm)
ma = MadArray(old_ma, mask_phase=self.mm, ma = MadArray(old_ma, mask_phase=self.mm, mask_magnitude=self.mp)
mask_magnitude=self.mp)
np.testing.assert_equal(ma.unknown_phase_mask, self.mm) np.testing.assert_equal(ma.unknown_phase_mask, self.mm)
np.testing.assert_equal(ma.unknown_magnitude_mask, self.mp) np.testing.assert_equal(ma.unknown_magnitude_mask, self.mp)
...@@ -325,14 +324,16 @@ class TestMadArray: ...@@ -325,14 +324,16 @@ class TestMadArray:
ma = MadArray(old_ma, mask_phase=self.mm) ma = MadArray(old_ma, mask_phase=self.mm)
np.testing.assert_equal(ma.unknown_phase_mask, self.mm) np.testing.assert_equal(ma.unknown_phase_mask, self.mm)
np.testing.assert_equal(ma.unknown_magnitude_mask, self.mm) np.testing.assert_equal(ma.unknown_magnitude_mask,
np.zeros_like(x, dtype=np.bool))
assert id(old_ma) != id(ma) assert id(old_ma) != id(ma)
assert id(old_ma._mask) != id(ma._mask) assert id(old_ma._mask) != id(ma._mask)
old_ma = MadArray(x, mask_phase=self.mp, mask_magnitude=self.mm) old_ma = MadArray(x, mask_phase=self.mp, mask_magnitude=self.mm)
ma = MadArray(old_ma, mask_magnitude=self.mp) ma = MadArray(old_ma, mask_magnitude=self.mp)
np.testing.assert_equal(ma.unknown_phase_mask, self.mp) np.testing.assert_equal(ma.unknown_phase_mask,
np.zeros_like(x, dtype=np.bool))
np.testing.assert_equal(ma.unknown_magnitude_mask, self.mp) np.testing.assert_equal(ma.unknown_magnitude_mask, self.mp)
assert id(old_ma) != id(ma) assert id(old_ma) != id(ma)
assert id(old_ma._mask) != id(ma._mask) assert id(old_ma._mask) != id(ma._mask)
...@@ -350,7 +351,8 @@ class TestMadArray: ...@@ -350,7 +351,8 @@ class TestMadArray:
ma = MadArray(old_ma, mask_phase=self.mp) ma = MadArray(old_ma, mask_phase=self.mp)
np.testing.assert_equal(ma.unknown_phase_mask, self.mp) np.testing.assert_equal(ma.unknown_phase_mask, self.mp)
np.testing.assert_equal(ma.unknown_magnitude_mask, self.m) np.testing.assert_equal(ma.unknown_magnitude_mask,
np.zeros_like(x, dtype=np.bool))
assert id(old_ma) != id(ma) assert id(old_ma) != id(ma)
assert id(old_ma._mask) != id(ma._mask) assert id(old_ma._mask) != id(ma._mask)
assert ma._complex_masking assert ma._complex_masking
...@@ -358,7 +360,8 @@ class TestMadArray: ...@@ -358,7 +360,8 @@ class TestMadArray:
old_ma = MadArray(x, self.m) old_ma = MadArray(x, self.m)
ma = MadArray(old_ma, mask_magnitude=self.mm) ma = MadArray(old_ma, mask_magnitude=self.mm)
np.testing.assert_equal(ma.unknown_phase_mask, self.m) np.testing.assert_equal(ma.unknown_phase_mask,
np.zeros_like(x, dtype=np.bool))
np.testing.assert_equal(ma.unknown_magnitude_mask, self.mm) np.testing.assert_equal(ma.unknown_magnitude_mask, self.mm)
assert id(old_ma) != id(ma) assert id(old_ma) != id(ma)
assert id(old_ma._mask) != id(ma._mask) assert id(old_ma._mask) != id(ma._mask)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment