diff --git a/madarrays/mad_array.py b/madarrays/mad_array.py index e9478f98e9bfeed5b4faab956f6ef149a0c054ee..c18200cf58b53bb1401065aa571189efcc026234 100644 --- a/madarrays/mad_array.py +++ b/madarrays/mad_array.py @@ -524,16 +524,29 @@ class MadArray(np.ndarray): string = '<MadArray at {}>' return string.format(hex(id(self))) - def get_known_mask(self, mask_type='any'): + def get_known_mask(self, mask_type='all'): """Boolean mask for known coefficients. - TODO: finish this + Compute the boolean mask marking known coefficients as True. Parameters ---------- - mask_type : {'any', 'all', 'magnitude', 'phase', 'magnitude only', \ + mask_type : {'all', 'any', 'magnitude', 'phase', 'magnitude only', \ 'phase only'} - TODO + Type of mask: + + - ``all``: mark coefficients for wich both the magnitude and the + phase are known, + - ``any``: mark coefficients for wich the magnitude or the phase + are known (including when both the magnitude and the phase are + known), + - ``magnitude``: mark coefficients for wich the magnitude is + known, + - ``phase``: mark coefficients for wich the phase is known, + - ``magnitude only``: mark coefficients for wich both the magnitude + is known and the phase is unknown, + - ``phase only``: mark coefficients for wich both the phase is + known and the magnitude is unknown. Returns ------- @@ -546,18 +559,45 @@ class MadArray(np.ndarray): ValueError If ``mask_type`` has an invalid value. """ - return ~self.get_unknown_mask(mask_type) + if mask_type == 'all': + return ~self.get_unknown_mask('any') + elif mask_type == 'any': + return ~self.get_unknown_mask('all') + elif mask_type == 'magnitude': + return ~self.get_unknown_mask('magnitude') + elif mask_type == 'phase': + return ~self.get_unknown_mask('phase') + elif mask_type == 'magnitude only': + return self.get_unknown_mask('phase only') + elif mask_type == 'phase only': + return self.get_unknown_mask('magnitude only') + + errmsg = 'Invalid value for mask_type: {}'.format(mask_type) + raise ValueError(errmsg) def get_unknown_mask(self, mask_type='any'): """Boolean mask for unknown coefficients. - TODO: finish this + Compute the boolean mask marking unknown coefficients as True. Parameters ---------- mask_type : {'any', 'all', 'magnitude', 'phase', 'magnitude only', \ 'phase only'} - TODO + Type of mask: + + - ``any``: mark coefficients for wich the magnitude or the phase + are unknown (including when both the magnitude and the phase are + unknown), + - ``all``: mark coefficients for wich both the magnitude and the + phase are unknown, + - ``magnitude``: mark coefficients for wich the magnitude is + unknown, + - ``phase``: mark coefficients for wich the phase is unknown, + - ``magnitude only``: mark coefficients for wich both the magnitude + is unknown and the phase is known, + - ``phase only``: mark coefficients for wich both the phase is + unknown and the magnitude is known. Returns ------- diff --git a/madarrays/tests/test_madarray.py b/madarrays/tests/test_madarray.py index f700816c7a88e99158400f0a8e1703b25b35ba85..87d07d5c9fb648a85e0f41930267991379c52ada 100644 --- a/madarrays/tests/test_madarray.py +++ b/madarrays/tests/test_madarray.py @@ -138,11 +138,11 @@ class TestMadArray: np.testing.assert_equal(ma.get_unknown_mask('phase only'), self.empty_mask) np.testing.assert_equal(ma.get_known_mask('phase only'), - self.full_mask) + self.empty_mask) np.testing.assert_equal(ma.get_unknown_mask('magnitude only'), self.empty_mask) np.testing.assert_equal(ma.get_known_mask('magnitude only'), - self.full_mask) + self.empty_mask) match = 'Invalid value for mask_type: error' with pytest.raises(ValueError, match=match): @@ -183,22 +183,34 @@ class TestMadArray: assert ma.is_masked() np.testing.assert_equal(ma.to_np_array(), x) - np.testing.assert_equal(ma.get_known_mask(), - np.logical_and(~self.mm, ~self.mp)) np.testing.assert_equal(ma.get_unknown_mask(), - np.logical_or(self.mm, self.mp)) + self.mm |self.mp) + np.testing.assert_equal(ma.get_known_mask(), + ~self.mm & ~self.mp) np.testing.assert_equal(ma.get_unknown_mask('any'), - np.logical_or(self.mm, self.mp)) + self.mm | self.mp) + np.testing.assert_equal(ma.get_known_mask('any'), + ~self.mm | ~self.mp) np.testing.assert_equal(ma.get_unknown_mask('all'), - np.logical_and(self.mm, self.mp)) - np.testing.assert_equal(ma.get_known_mask('phase'), ~self.mp) - np.testing.assert_equal(ma.get_unknown_mask('phase'), self.mp) - np.testing.assert_equal(ma.get_known_mask('magnitude'), ~self.mm) - np.testing.assert_equal(ma.get_unknown_mask('magnitude'), self.mm) + self.mm & self.mp) + np.testing.assert_equal(ma.get_known_mask('all'), + ~self.mm & ~self.mp) + np.testing.assert_equal(ma.get_unknown_mask('phase'), + self.mp) + np.testing.assert_equal(ma.get_known_mask('phase'), + ~self.mp) + np.testing.assert_equal(ma.get_unknown_mask('magnitude'), + self.mm) + np.testing.assert_equal(ma.get_known_mask('magnitude'), + ~self.mm) np.testing.assert_equal(ma.get_unknown_mask('phase only'), - np.logical_and(self.mp, ~self.mm)) + self.mp & ~self.mm) + np.testing.assert_equal(ma.get_known_mask('phase only'), + ~self.mp & self.mm) np.testing.assert_equal(ma.get_unknown_mask('magnitude only'), - np.logical_and(~self.mp, self.mm)) + self.mm & ~self.mp) + np.testing.assert_equal(ma.get_unknown_mask('phase only'), + ~self.mm & self.mp) match = 'Invalid value for mask_type: error' with pytest.raises(ValueError, match=match):