diff --git a/madarrays/mad_array.py b/madarrays/mad_array.py index deba53d6f0d19989e4a88006c91ec1f93fa0a582..9fcfac603bf8c9b2031e843c345efa231f0f16cf 100644 --- a/madarrays/mad_array.py +++ b/madarrays/mad_array.py @@ -248,11 +248,9 @@ class MadArray(np.ndarray): _data = _data.astype(np.complex) if mask_magnitude is None: - if isinstance(data, MadArray): - if data._complex_masking: - mask_magnitude = data.unknown_magnitude_mask - else: - mask_magnitude = data.unknown_mask + if isinstance(data, MadArray) and mask_phase is None: + mask_magnitude = data.unknown_magnitude_mask + mask_phase = data.unknown_phase_mask else: mask_magnitude = np.zeros_like(data, dtype=np.bool) else: @@ -264,13 +262,7 @@ class MadArray(np.ndarray): _data.shape)) if mask_phase is None: - if isinstance(data, MadArray): - if data._complex_masking: - mask_phase = data.unknown_phase_mask - else: - mask_phase = data.unknown_mask - else: - mask_phase = np.zeros_like(data, dtype=np.bool) + mask_phase = np.zeros_like(data, dtype=np.bool) else: mask_phase = np.array(mask_phase, dtype=np.bool) diff --git a/madarrays/tests/test_madarray.py b/madarrays/tests/test_madarray.py index 046febed51028b3d6639647547dd736a2a9bc98d..fdf0fff117ee2781dbf0dca6a2fb1f88b5046a52 100644 --- a/madarrays/tests/test_madarray.py +++ b/madarrays/tests/test_madarray.py @@ -313,8 +313,7 @@ class TestMadArray: assert ma._complex_masking old_ma = MadArray(x, mask_phase=self.mp, mask_magnitude=self.mm) - ma = MadArray(old_ma, mask_phase=self.mm, - mask_magnitude=self.mp) + ma = MadArray(old_ma, mask_phase=self.mm, mask_magnitude=self.mp) np.testing.assert_equal(ma.unknown_phase_mask, self.mm) np.testing.assert_equal(ma.unknown_magnitude_mask, self.mp) @@ -325,14 +324,16 @@ class TestMadArray: ma = MadArray(old_ma, mask_phase=self.mm) np.testing.assert_equal(ma.unknown_phase_mask, self.mm) - np.testing.assert_equal(ma.unknown_magnitude_mask, self.mm) + np.testing.assert_equal(ma.unknown_magnitude_mask, + np.zeros_like(x, dtype=np.bool)) assert id(old_ma) != id(ma) assert id(old_ma._mask) != id(ma._mask) old_ma = MadArray(x, mask_phase=self.mp, mask_magnitude=self.mm) ma = MadArray(old_ma, mask_magnitude=self.mp) - np.testing.assert_equal(ma.unknown_phase_mask, self.mp) + np.testing.assert_equal(ma.unknown_phase_mask, + np.zeros_like(x, dtype=np.bool)) np.testing.assert_equal(ma.unknown_magnitude_mask, self.mp) assert id(old_ma) != id(ma) assert id(old_ma._mask) != id(ma._mask) @@ -350,7 +351,8 @@ class TestMadArray: ma = MadArray(old_ma, mask_phase=self.mp) np.testing.assert_equal(ma.unknown_phase_mask, self.mp) - np.testing.assert_equal(ma.unknown_magnitude_mask, self.m) + np.testing.assert_equal(ma.unknown_magnitude_mask, + np.zeros_like(x, dtype=np.bool)) assert id(old_ma) != id(ma) assert id(old_ma._mask) != id(ma._mask) assert ma._complex_masking @@ -358,7 +360,8 @@ class TestMadArray: old_ma = MadArray(x, self.m) ma = MadArray(old_ma, mask_magnitude=self.mm) - np.testing.assert_equal(ma.unknown_phase_mask, self.m) + np.testing.assert_equal(ma.unknown_phase_mask, + np.zeros_like(x, dtype=np.bool)) np.testing.assert_equal(ma.unknown_magnitude_mask, self.mm) assert id(old_ma) != id(ma) assert id(old_ma._mask) != id(ma._mask)