class MaskSeparationBase(separation_base.SeparationBase):
    """
    Base class for separation algorithms that create a mask (binary or soft) to do their separation. Most algorithms
    in nussl are derived from :class:`MaskSeparationBase`.

    Although this class will do nothing if you instantiate and run it by itself, algorithms that are derived from this
    class are expected to return a list of :class:`separation.masks.mask_base.MaskBase` -derived objects
    (i.e., either a :class:`separation.masks.binary_mask.BinaryMask` or :class:`separation.masks.soft_mask.SoftMask`
    object) by their :func:`run()` method. Being a subclass of :class:`MaskSeparationBase` is an implicit contract
    assuring this.  Returning a :class:`separation.masks.mask_base.MaskBase`-derived object standardizes
    algorithm return types for :class:`evaluation.evaluation_base.EvaluationBase`-derived objects.

    Args:
        input_audio_signal: (:class:`audio_signal.AudioSignal`) An :class:`audio_signal.AudioSignal` object containing
            the mixture to be separated.
        mask_type: (str) Indicates whether to make binary or soft masks. See :attr:`mask_type` property for details.
        mask_threshold: (float) Value between [0.0, 1.0] to convert a soft mask to a binary mask. See
            :attr:`mask_threshold` property for details.

    """

    BINARY_MASK = 'binary'
    """ String alias for setting this object to return :class:`separation.masks.binary_mask.BinaryMask` objects
    """

    SOFT_MASK = 'soft'
    """ String alias for setting this object to return :class:`separation.masks.soft_mask.SoftMask` objects
    """

    _valid_mask_types = [BINARY_MASK, SOFT_MASK]

    def __init__(self, input_audio_signal, mask_type=SOFT_MASK, mask_threshold=0.5):
        super().__init__(input_audio_signal=input_audio_signal)

        # Both values go through their property setters so they are validated on construction.
        self._mask_type = None
        self.mask_type = mask_type
        self._mask_threshold = None
        self.mask_threshold = mask_threshold
        self.result_masks = []

    @property
    def mask_type(self):
        """
        PROPERTY

        This property indicates what type of mask the derived algorithm will create and be returned by :func:`run()`.
        Options are either 'soft' or 'binary'.
        :attr:`mask_type` is usually set when initializing a :class:`MaskSeparationBase`-derived class
        and defaults to :attr:`SOFT_MASK`.

        This property, though stored as a string, can be set in two ways when initializing:

        * First, it is possible to set this property with a string. Only ``'soft'`` and ``'binary'`` are accepted
          (case insensitive), every other value will raise an error. When initializing with a string, two helper
          attributes are provided: :attr:`BINARY_MASK` and :attr:`SOFT_MASK`.

          It is **HIGHLY** encouraged to use these, as the API may change and code that uses bare strings
          (e.g. ``mask_type = 'soft'`` or ``mask_type = 'binary'``) for assignment might not be future-proof.
          :attr:`BINARY_MASK` and :attr:`SOFT_MASK` are safe aliases in case these underlying types change.

        * The second way to set this property is by using a class prototype of either the
          :class:`separation.masks.binary_mask.BinaryMask` or :class:`separation.masks.soft_mask.SoftMask` class
          prototype. This is probably the most stable way to set this, and it's fairly succinct.
          For example, ``mask_type = nussl.BinaryMask`` or ``mask_type = nussl.SoftMask`` are both perfectly valid.

        A mask *instance* is also accepted: its class name selects the type and a warning is issued because
        the values stored in that mask are not used.

        Though uncommon, this can be set outside of :func:`__init__()`

        Examples of both methods are shown below.

        Returns:
            mask_type (str): Either ``'soft'`` or ``'binary'``.

        Raises:
            ValueError if set invalidly.

        Example:

        .. code-block:: python
            :linenos:

            import nussl
            mixture_signal = nussl.AudioSignal()

            # Two options for determining mask upon init...

            # Option 1: Init with a string (BINARY_MASK is a string 'constant')
            repet_sim = nussl.RepetSim(mixture_signal, mask_type=nussl.MaskSeparationBase.BINARY_MASK)

            # Option 2: Init with a class type
            ola = nussl.OverlapAdd(mixture_signal, mask_type=nussl.SoftMask)

            # It's also possible to change these values after init by changing the `mask_type` property...
            repet_sim.mask_type = nussl.MaskSeparationBase.SOFT_MASK  # using a string
            ola.mask_type = nussl.BinaryMask  # or using a class type

        """
        return self._mask_type

    @mask_type.setter
    def mask_type(self, value):
        error = ValueError(
            f"Invalid mask type! Got {value} but valid masks are:"
            f" [{', '.join(self._valid_mask_types)}]!"
        )

        if isinstance(value, str):
            value = value.lower()
            if value not in self._valid_mask_types:
                raise error
            self._mask_type = value

        elif isinstance(value, masks.MaskBase):
            warnings.warn('This separation method is not using the values in the provided mask object.')
            # Derive the type from the class name, e.g. 'BinaryMask' -> 'binary'.
            value = type(value).__name__
            value = value[:value.find('Mask')].lower()
            if value not in self._valid_mask_types:
                # make sure we don't get duped by accident. This shouldn't happen
                raise error
            self._mask_type = value

        # issubclass() raises TypeError when handed a non-class, so check for a class first;
        # anything else (None, numbers, ...) must end in the documented ValueError below.
        elif isinstance(value, type) and issubclass(value, masks.MaskBase):
            if value is masks.BinaryMask:
                self._mask_type = self.BINARY_MASK
            elif value is masks.SoftMask:
                self._mask_type = self.SOFT_MASK
            else:
                raise error

        else:
            raise error

    @property
    def mask_threshold(self):
        """
        PROPERTY

        Threshold of determining True/False if :attr:`mask_type` is :attr:`BINARY_MASK`. Some algorithms will first
        make a soft mask and then convert that to a binary mask using this threshold parameter. All values of the
        soft mask are between ``[0.0, 1.0]`` and as such :func:`mask_threshold` is expected to be a float between
        ``[0.0, 1.0]`` (both endpoints included).

        Returns:
            mask_threshold (float): Value between ``[0.0, 1.0]`` that indicates the True/False cutoff when converting a
            soft mask to binary mask.

        Raises:
            ValueError if not a float or if set outside ``[0.0, 1.0]``.
        """
        return self._mask_threshold

    @mask_threshold.setter
    def mask_threshold(self, value):
        # The documented range is the closed interval, so 0.0 and 1.0 themselves are valid.
        if not isinstance(value, float) or not (0.0 <= value <= 1.0):
            raise ValueError('Mask threshold must be a float between [0.0, 1.0]!')

        self._mask_threshold = value

    def zeros_mask(self, shape):
        """
        Creates a new zeros mask with this object's type.

        Args:
            shape: Passed unchanged to the ``zeros`` constructor of the mask class.

        Returns:
            The result of ``masks.BinaryMask.zeros(shape)`` if :attr:`mask_type` is :attr:`BINARY_MASK`,
            otherwise the result of ``masks.SoftMask.zeros(shape)``.
        """
        if self.mask_type == self.BINARY_MASK:
            return masks.BinaryMask.zeros(shape)
        else:
            return masks.SoftMask.zeros(shape)

    def ones_mask(self, shape):
        """
        Creates a new ones mask with this object's type.

        Args:
            shape: Passed unchanged to the ``ones`` constructor of the mask class.

        Returns:
            The result of ``masks.BinaryMask.ones(shape)`` if :attr:`mask_type` is :attr:`BINARY_MASK`,
            otherwise the result of ``masks.SoftMask.ones(shape)``.
        """
        if self.mask_type == self.BINARY_MASK:
            return masks.BinaryMask.ones(shape)
        else:
            return masks.SoftMask.ones(shape)

    def run(self):
        """Runs mask-based separation algorithm. Base class: Do not call directly!

        Raises:
            NotImplementedError: Cannot call base class!
        """
        raise NotImplementedError('Cannot call base class!')

    def make_audio_signals(self):
        """Makes :class:`audio_signal.AudioSignal` objects after mask-based separation algorithm is run.
        Base class: Do not call directly!

        Raises:
            NotImplementedError: Cannot call base class!
        """
        raise NotImplementedError('Cannot call base class!')

    @classmethod
    def from_json(cls, json_string):
        """
        Creates a new :class:`MaskSeparationBase`-derived object from the parameters stored in this JSON string.

        Decoding is delegated to a :class:`MaskSeparationBaseDecoder` built for ``cls``.

        Args:
            json_string (str): A JSON string containing all the data to create a new object of this class.

        Returns:
            Whatever the decoder produces for ``json_string`` (expected to be a new object of this class).

        See Also:
            :func:`to_json` to make a JSON string to freeze this object.
        """
        mask_sep_decoder = MaskSeparationBaseDecoder(cls)
        return mask_sep_decoder.decode(json_string)
  
class MaskSeparationBaseDecoder(separation_base.SeparationBaseDecoder):
    """ Object to decode a :class:`MaskSeparationBase`-derived object from JSON serialization.
    You should never have to instantiate this object by hand.

    Args:
        separation_class: The class being decoded; :func:`MaskSeparationBase.from_json` passes its ``cls`` here.
            (Presumably used by the inherited ``_inspect_json_and_create_new_instance`` -- confirm in the parent.)
    """

    def __init__(self, separation_class):
        self.separation_class = separation_class
        # Install this class's hook so every decoded JSON object passes through it.
        json.JSONDecoder.__init__(self, object_hook=self._json_separation_decoder)

    def _json_separation_decoder(self, json_dict):
        """
        ``object_hook`` for the JSON decoder.

        Dictionaries that carry both ``'__class__'`` and ``'__module__'`` are turned into a separator object;
        every other dictionary is returned untouched.

        Args:
            json_dict (dict): A JSON object as decoded by :mod:`json`.

        Returns:
            The rebuilt separator, or ``json_dict`` itself if it does not describe one.
        """
        if '__class__' not in json_dict or '__module__' not in json_dict:
            return json_dict

        json_dict, separator = self._inspect_json_and_create_new_instance(json_dict)

        # fill out the rest of the fields
        # (iterating over a snapshot, in case the dict is touched while attributes are assigned)
        for k, v in list(json_dict.items()):
            if isinstance(v, dict) and constants.NUMPY_JSON_KEY in v:
                separator.__dict__[k] = utils.json_numpy_obj_hook(v[constants.NUMPY_JSON_KEY])

            # json only ever yields ``str`` for text, and a substring test of a ``str``
            # against ``bytes`` would raise TypeError, so only ``str`` is inspected here.
            elif isinstance(v, str) and audio_signal.__name__ in v:
                separator.__dict__[k] = audio_signal.AudioSignal.from_json(v)

            elif k == 'result_masks':
                separator.result_masks = [masks.MaskBase.from_json(itm) for itm in v]

            else:
                # Store the value as decoded. Encoding strings to ASCII would produce ``bytes``
                # in Python 3, and e.g. ``b'binary'`` never equals ``MaskSeparationBase.BINARY_MASK``.
                separator.__dict__[k] = v

        return separator