import copy

import librosa
import numpy as np
import torch

# NOTE: the intra-package imports below are assumptions about nussl's layout;
# adjust them if MaskSeparationBase, DeepMixin, or the masks module live
# elsewhere in your version.
from ..core import masks
from .deep_mixin import DeepMixin
from .mask_separation_base import MaskSeparationBase


class DeepMaskEstimation(MaskSeparationBase, DeepMixin):
    """Implements deep mask-estimation source separation using a trained PyTorch model."""
def __init__(
self,
input_audio_signal,
model_path,
extra_modules=None,
mask_type='soft',
use_librosa_stft=False,
use_cuda=True,
):
super(DeepMaskEstimation, self).__init__(
input_audio_signal=input_audio_signal,
mask_type=mask_type
)
self.device = torch.device(
'cuda'
if torch.cuda.is_available() and use_cuda
else 'cpu'
)
self.model, self.metadata = self.load_model(model_path, extra_modules=extra_modules)
        # Match the input signal's sample rate and STFT parameters to the values
        # the model was trained with (stored in the model's metadata).
        if input_audio_signal.sample_rate != self.metadata['sample_rate']:
            input_audio_signal.resample(self.metadata['sample_rate'])
        input_audio_signal.stft_params.window_length = self.metadata['n_fft']
        input_audio_signal.stft_params.n_fft_bins = self.metadata['n_fft']
        input_audio_signal.stft_params.hop_length = self.metadata['hop_length']
self.use_librosa_stft = use_librosa_stft
self._compute_spectrograms()

    def _compute_spectrograms(self):
self.stft = self.audio_signal.stft(
overwrite=True,
remove_reflection=True,
use_librosa=self.use_librosa_stft
)
self.log_spectrogram = librosa.amplitude_to_db(
np.abs(self.stft),
ref=np.max
)

    def run(self):
        """Runs the loaded model on the mixture and estimates one mask per source.

        Returns:
            self.masks (list): A list of mask objects (soft masks by default,
                converted to binary masks when ``mask_type`` is ``'binary'``),
                one per estimated source.
        """
        input_data = self._preprocess()
        with torch.no_grad():
            output = self.model(input_data)
            output = {k: output[k] for k in output}
        if 'estimates' not in output:
            raise ValueError("This model is not a mask estimation model!")

        # Divide the estimated source spectrograms by the mixture magnitude to
        # obtain soft masks, then move the source axis to the front so that
        # self.assignments[i] is the mask for source i.
        _masks = (
            output['estimates'] / input_data['magnitude_spectrogram'].unsqueeze(-1)
        ).squeeze(0)
        _masks = _masks.permute(3, 1, 0, 2)
        _masks = _masks.cpu().data.numpy()
        self.assignments = _masks
        self.num_sources = self.assignments.shape[0]

        self.masks = []
        for i in range(self.num_sources):
            mask = self.assignments[i, :, :, :]
            mask = masks.SoftMask(mask)
            if self.mask_type == self.BINARY_MASK:
                # Keep a bin for a source when its soft value exceeds
                # 1 / num_sources (i.e., it beats a uniform split).
                mask = mask.mask_to_binary(1 / self.num_sources)
            self.masks.append(mask)
        return self.masks

    def apply_mask(self, mask):
        """Applies an individual mask to the mixture and returns the masked
        audio as an AudioSignal object.
        """
source = copy.deepcopy(self.audio_signal)
source = source.apply_mask(mask)
source.stft_params = self.audio_signal.stft_params
source.istft(
overwrite=True,
truncate_to_length=self.audio_signal.signal_length
)
return source

    def make_audio_signals(self):
        """Applies each mask in ``self.masks`` and returns a list of AudioSignal
        objects, one for each separated source.

        Returns:
            self.sources (list): A list of AudioSignal objects containing each
                separated source.
        """
self.sources = []
for mask in self.masks:
self.sources.append(self.apply_mask(mask))
return self.sources
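

# Example usage (a minimal sketch, not part of the original module): load a
# mixture, estimate masks with a trained model, and write each source to disk.
# AudioSignal and write_audio_to_file are nussl's audio container and writer;
# 'mixture.wav' and 'checkpoints/mask_model.pth' are hypothetical paths to an
# input mixture and a trained mask-estimation checkpoint.
if __name__ == '__main__':
    from nussl import AudioSignal

    mixture = AudioSignal('mixture.wav')
    separator = DeepMaskEstimation(
        mixture,
        model_path='checkpoints/mask_model.pth',
        mask_type='soft',
    )
    separator.run()                             # one mask per estimated source
    estimates = separator.make_audio_signals()  # list of AudioSignal objects
    for i, estimate in enumerate(estimates):
        estimate.write_audio_to_file(f'source_{i}.wav')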