Skip to content
Snippets Groups Projects
Commit af54a8ae authored by Florent Jaillet's avatar Florent Jaillet
Browse files

Correct get_known_mask() and finalize doctrings

parent a896393b
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -524,16 +524,29 @@ class MadArray(np.ndarray):
string = '<MadArray at {}>'
return string.format(hex(id(self)))
def get_known_mask(self, mask_type='any'):
def get_known_mask(self, mask_type='all'):
"""Boolean mask for known coefficients.
TODO: finish this
Compute the boolean mask marking known coefficients as True.
Parameters
----------
mask_type : {'any', 'all', 'magnitude', 'phase', 'magnitude only', \
mask_type : {'all', 'any', 'magnitude', 'phase', 'magnitude only', \
'phase only'}
TODO
Type of mask:
- ``all``: mark coefficients for wich both the magnitude and the
phase are known,
- ``any``: mark coefficients for wich the magnitude or the phase
are known (including when both the magnitude and the phase are
known),
- ``magnitude``: mark coefficients for wich the magnitude is
known,
- ``phase``: mark coefficients for wich the phase is known,
- ``magnitude only``: mark coefficients for wich both the magnitude
is known and the phase is unknown,
- ``phase only``: mark coefficients for wich both the phase is
known and the magnitude is unknown.
Returns
-------
......@@ -546,18 +559,45 @@ class MadArray(np.ndarray):
ValueError
If ``mask_type`` has an invalid value.
"""
return ~self.get_unknown_mask(mask_type)
if mask_type == 'all':
return ~self.get_unknown_mask('any')
elif mask_type == 'any':
return ~self.get_unknown_mask('all')
elif mask_type == 'magnitude':
return ~self.get_unknown_mask('magnitude')
elif mask_type == 'phase':
return ~self.get_unknown_mask('phase')
elif mask_type == 'magnitude only':
return self.get_unknown_mask('phase only')
elif mask_type == 'phase only':
return self.get_unknown_mask('magnitude only')
errmsg = 'Invalid value for mask_type: {}'.format(mask_type)
raise ValueError(errmsg)
def get_unknown_mask(self, mask_type='any'):
"""Boolean mask for unknown coefficients.
TODO: finish this
Compute the boolean mask marking unknown coefficients as True.
Parameters
----------
mask_type : {'any', 'all', 'magnitude', 'phase', 'magnitude only', \
'phase only'}
TODO
Type of mask:
- ``any``: mark coefficients for wich the magnitude or the phase
are unknown (including when both the magnitude and the phase are
unknown),
- ``all``: mark coefficients for wich both the magnitude and the
phase are unknown,
- ``magnitude``: mark coefficients for wich the magnitude is
unknown,
- ``phase``: mark coefficients for wich the phase is unknown,
- ``magnitude only``: mark coefficients for wich both the magnitude
is unknown and the phase is known,
- ``phase only``: mark coefficients for wich both the phase is
unknown and the magnitude is known.
Returns
-------
......
......@@ -138,11 +138,11 @@ class TestMadArray:
np.testing.assert_equal(ma.get_unknown_mask('phase only'),
self.empty_mask)
np.testing.assert_equal(ma.get_known_mask('phase only'),
self.full_mask)
self.empty_mask)
np.testing.assert_equal(ma.get_unknown_mask('magnitude only'),
self.empty_mask)
np.testing.assert_equal(ma.get_known_mask('magnitude only'),
self.full_mask)
self.empty_mask)
match = 'Invalid value for mask_type: error'
with pytest.raises(ValueError, match=match):
......@@ -183,22 +183,34 @@ class TestMadArray:
assert ma.is_masked()
np.testing.assert_equal(ma.to_np_array(), x)
np.testing.assert_equal(ma.get_known_mask(),
np.logical_and(~self.mm, ~self.mp))
np.testing.assert_equal(ma.get_unknown_mask(),
np.logical_or(self.mm, self.mp))
self.mm |self.mp)
np.testing.assert_equal(ma.get_known_mask(),
~self.mm & ~self.mp)
np.testing.assert_equal(ma.get_unknown_mask('any'),
np.logical_or(self.mm, self.mp))
self.mm | self.mp)
np.testing.assert_equal(ma.get_known_mask('any'),
~self.mm | ~self.mp)
np.testing.assert_equal(ma.get_unknown_mask('all'),
np.logical_and(self.mm, self.mp))
np.testing.assert_equal(ma.get_known_mask('phase'), ~self.mp)
np.testing.assert_equal(ma.get_unknown_mask('phase'), self.mp)
np.testing.assert_equal(ma.get_known_mask('magnitude'), ~self.mm)
np.testing.assert_equal(ma.get_unknown_mask('magnitude'), self.mm)
self.mm & self.mp)
np.testing.assert_equal(ma.get_known_mask('all'),
~self.mm & ~self.mp)
np.testing.assert_equal(ma.get_unknown_mask('phase'),
self.mp)
np.testing.assert_equal(ma.get_known_mask('phase'),
~self.mp)
np.testing.assert_equal(ma.get_unknown_mask('magnitude'),
self.mm)
np.testing.assert_equal(ma.get_known_mask('magnitude'),
~self.mm)
np.testing.assert_equal(ma.get_unknown_mask('phase only'),
np.logical_and(self.mp, ~self.mm))
self.mp & ~self.mm)
np.testing.assert_equal(ma.get_known_mask('phase only'),
~self.mp & self.mm)
np.testing.assert_equal(ma.get_unknown_mask('magnitude only'),
np.logical_and(~self.mp, self.mm))
self.mm & ~self.mp)
np.testing.assert_equal(ma.get_unknown_mask('phase only'),
~self.mm & self.mp)
match = 'Invalid value for mask_type: error'
with pytest.raises(ValueError, match=match):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment