from .clustering_separation_base import ClusteringSeparationBase
import numpy as np
import torch
import librosa
from ..deep_mixin import DeepMixin
from .. import FT2D, Melodia, HPSS, Repet, RepetSim, MultichannelWienerFilter
from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import scale
from copy import deepcopy


class SpatialClustering(ClusteringSeparationBase):
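    """
    Clusters time-frequency points by their spatial cues: the interaural
    phase difference (IPD) and interaural level difference (ILD) of every
    time-frequency point, stacked across channels. The clustering and mask
    construction are handled by ClusteringSeparationBase.
    """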

    def extract_features(self):
        # Interaural phase and level differences for each time-frequency bin.
        ipd, ild = self.audio_signal.ipd_ild_features()
        num_channels = self.audio_signal.num_channels
        # Duplicate the IPD/ILD features for every channel so each channel's
        # time-frequency points share the same feature vector.
        features = [
            [ipd for _ in range(num_channels)],
            [ild for _ in range(num_channels)],
        ]
        # Rearrange so the last axis indexes the feature (IPD vs. ILD) and
        # the leading axes index time-frequency points and channels.
        features = np.array(features).transpose(2, 3, 1, 0)
        # Flatten to (num_points, num_features) for clustering.
        features = features.reshape(-1, features.shape[-1])
        return features
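

# Example usage of SpatialClustering (a sketch; it assumes this module is
# used inside nussl, that the mixture is stereo, and that `num_sources` is a
# keyword argument accepted by ClusteringSeparationBase):
#
#     mix = nussl.AudioSignal('stereo_mixture.wav')
#     separator = SpatialClustering(mix, num_sources=2)
#     separator.run()
#     estimates = separator.make_audio_signals()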


class PrimitiveClustering(ClusteringSeparationBase):
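    """
    Clusters time-frequency points using the outputs of an ensemble of
    primitive separation algorithms. Each entry of ``algorithms`` is a tuple
    of the algorithm class, optionally a dict of keyword arguments for it,
    and optionally a list of indices selecting which of its returned sources
    to keep. The masks produced by every algorithm become per-point features
    that are clustered by ClusteringSeparationBase. With ``num_cascades > 1``
    the ensemble is re-run at each level on the outputs of the previous one.
    """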

    def __init__(
        self,
        input_audio_signal,
        algorithms,
        num_cascades=1,
        scale_features=False,
        **kwargs
    ):
        super().__init__(
            input_audio_signal,
            **kwargs
        )
        # Remember the mixture's STFT parameters so they can be restored
        # after algorithms that use their own STFT settings.
        self.original_stft_params = deepcopy(self.audio_signal.stft_params)
        # Unpack (class, parameters, returns) tuples; the parameters dict and
        # the list of return indices are both optional.
        self.algorithms = [a[0] for a in algorithms]
        self.algorithm_parameters = [a[1] if len(a) > 1 else {} for a in algorithms]
        self.algorithm_returns = [a[2] if len(a) > 2 else [] for a in algorithms]
        self.num_cascades = num_cascades
        self.scale_features = scale_features
        self.separators = self.setup_algorithms()

    def setup_algorithms(self):
        separators = []
        mixture = deepcopy(self.audio_signal)
        for i, algorithm in enumerate(self.algorithms):
            # An algorithm may carry its own STFT parameters; apply them to
            # the copied mixture and invalidate any cached STFT.
            stft_params = self.algorithm_parameters[i].pop('stft_params', None)
            if stft_params is not None:
                mixture.stft_data = None
                mixture.stft_params = stft_params
            separator = algorithm(
                mixture,
                use_librosa_stft=self.use_librosa_stft,
                **self.algorithm_parameters[i]
            )
            # Restore the original STFT parameters for the next algorithm.
            mixture.stft_params = self.original_stft_params
            separators.append(separator)
        return separators
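
    # The ``algorithms`` argument consumed above is a list of tuples (a
    # sketch; the chosen algorithms, indices, and ``my_stft_params`` are
    # illustrative, not canonical):
    #
    #     algorithms = [
    #         (Repet, {}, [0]),  # keep only Repet's first returned source
    #         (FT2D,),           # keep everything FT2D returns
    #         (HPSS, {'stft_params': my_stft_params}),  # custom STFT settings
    #     ]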

    def set_audio_signal(self, new_audio_signal):
        super().set_audio_signal(new_audio_signal)
        # Rebuild the separators so they point at the new signal.
        self.separators = self.setup_algorithms()

    def run_algorithm_on_signal(self, mixture, level):
        separations = []
        for i, separator in enumerate(self.separators):
            separator.run()
            signals = separator.make_audio_signals()
            # Keep only the requested subset of returned sources, if any.
            if self.algorithm_returns[i]:
                signals = [signals[j] for j in self.algorithm_returns[i]]
            separations += signals
        return separations, self.separators

    def extract_features_from_signals(self, signals):
        features = []
        # Invalidate any cached STFT and recompute the mixture spectrogram
        # under the original STFT parameters.
        self.audio_signal.stft_data = None
        self.audio_signal.stft_params = self.original_stft_params
        mix_stft = np.abs(self.audio_signal.stft())
        for s in signals:
            s.stft_data = None
            s.stft_params = self.original_stft_params
            _stft = np.abs(s.stft())
            # Soft-mask-like ratio in (0, 1]: the separated magnitude over
            # the (regularized) elementwise maximum with the mixture.
            _feature = _stft / np.maximum(_stft, mix_stft + 1e-7)
            features.append(_feature)
        # Move the signal axis last: (num_signals, ...) -> (..., num_signals).
        features = np.array(features).transpose(1, 2, 3, 0)
        return features

    def extract_features_from_separators(self, separators):
        features = []
        for i, s in enumerate(separators):
            # Use the masks each separator produced as features, keeping only
            # the requested subset when one was given.
            masks = [m.mask for m in s.result_masks]
            if self.algorithm_returns[i]:
                masks = [masks[j] for j in self.algorithm_returns[i]]
            features += masks
        # Move the mask axis last: (num_masks, ...) -> (..., num_masks).
        features = np.array(features).transpose(1, 2, 3, 0)
        return features

    def extract_features(self):
        current_signals = [self.audio_signal]
        separators = []
        # Each cascade level runs the ensemble again, accumulating the
        # resulting separators and separated signals.
        for i in range(self.num_cascades):
            separations = []
            for signal in current_signals:
                _separations, _separators = self.run_algorithm_on_signal(signal, i)
                separations += _separations
                separators += _separators
            current_signals = separations
        self.separations = separations
        features = self.extract_features_from_separators(separators)
        self._compute_spectrograms()
        # Flatten to (num_points, num_features) for clustering.
        features = features.reshape(-1, features.shape[-1])
        if self.scale_features:
            features = scale(features, axis=0)
        return features
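

# Example usage of PrimitiveClustering (a sketch; the file name and source
# count are hypothetical, and Repet and FT2D are imported at the top of this
# module):
#
#     mix = nussl.AudioSignal('mixture.wav')
#     separator = PrimitiveClustering(
#         mix, algorithms=[(Repet,), (FT2D,)], num_sources=2)
#     separator.run()
#     estimates = separator.make_audio_signals()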


class DeepClustering(ClusteringSeparationBase, DeepMixin):
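    """
    Clusters time-frequency points using embeddings produced by a trained
    deep clustering network. The model is loaded from ``model_path`` via
    DeepMixin; the input signal is resampled and its STFT parameters are
    matched to the model's metadata. If the network operates on a mel
    projection of the spectrogram, cluster assignments are projected back to
    the linear-frequency domain in ``postprocess`` using the pseudoinverse of
    the mel filter bank.
    """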

    def __init__(
        self,
        input_audio_signal,
        model_path,
        metadata=None,
        extra_modules=None,
        use_cuda=False,
        **kwargs
    ):
        self.device = torch.device(
            'cuda'
            if torch.cuda.is_available() and use_cuda
            else 'cpu'
        )
        self.metadata = metadata
        self.model, self.metadata = self.load_model(
            model_path, extra_modules=extra_modules)
        input_audio_signal = self.set_audio_signal(input_audio_signal)
        num_mels = self.model.layers['mel_projection'].num_mels
        # If the network operates on a mel projection, keep the pseudoinverse
        # of the mel filter bank so postprocess can map cluster assignments
        # back to linear frequency.
        filter_bank = None
        if num_mels > 0:
            weights = (
                self.model.layers['mel_projection']
                .transform.weight.data.cpu().numpy()
            )
            filter_bank = np.linalg.pinv(weights.T)
        self.filter_bank = filter_bank
        super().__init__(input_audio_signal, **kwargs)

    def set_audio_signal(self, new_audio_signal):
        # Work on a copy so the caller's signal is untouched; match the
        # signal's sample rate and STFT parameters to the model's metadata.
        input_audio_signal = deepcopy(new_audio_signal)
        if input_audio_signal.sample_rate != self.metadata['sample_rate']:
            input_audio_signal.resample(self.metadata['sample_rate'])
        input_audio_signal.stft_params.window_length = self.metadata['n_fft']
        input_audio_signal.stft_params.n_fft_bins = self.metadata['n_fft']
        input_audio_signal.stft_params.hop_length = self.metadata['hop_length']
        input_audio_signal = super().set_audio_signal(input_audio_signal)
        return input_audio_signal

    def postprocess(self, assignments, confidence):
        if self.filter_bank is not None:
            # Assignments live on mel bins; reshape, project back to linear
            # frequency with the pseudoinverse filter bank, and renormalize
            # so the source masks sum to one at every point.
            shape = (self.filter_bank.shape[0], -1, self.stft.shape[-1])
            assignments = assignments.reshape(shape + (self.num_sources,))
            confidence = confidence.reshape(shape)

            assignments = assignments.transpose()
            confidence = confidence.transpose()

            assignments = np.dot(assignments, self.filter_bank) + 1e-6
            assignments = np.clip(assignments, 0.0, 1.0)
            assignments /= np.sum(assignments, axis=0)

            confidence = np.dot(confidence, self.filter_bank)
            confidence += np.abs(confidence.min())

            assignments = assignments.transpose()
            confidence = confidence.transpose()
            assignments = assignments.transpose(3, 0, 1, 2)
        else:
            assignments, confidence = super().postprocess(assignments, confidence)
        return assignments, confidence
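
    # The projection above is, in effect (with W the mel projection weight
    # matrix applied by the model and A the mel-domain assignments):
    #
    #     A_linear ≈ clip(A @ pinv(W.T), 0, 1), renormalized over sources
    #
    # This is only an approximate inverse, since W is not square.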

    def project_data(self, data):
        if self.model.layers['mel_projection'].num_mels > 0:
            # Run the data through the model's mel projection, then binarize.
            data = self._format(data, 'rnn')
            data = torch.from_numpy(data).to(self.device).float()
            data = self.model.project_data(data, clamp=False)
            data = data.squeeze(-1).permute(2, 1, 0)
            data = (data > 0).cpu().data.numpy().astype(bool)
        return data

    def make_audio_signals(self):
        signals = super().make_audio_signals()
        # Distribute the reconstruction residual evenly across the estimates
        # so the returned sources still sum to the original mixture.
        residual = (self.audio_signal - sum(signals)).audio_data / len(signals)
        for signal in signals:
            signal.audio_data += residual
            if signal.sample_rate != self.original_sample_rate:
                signal.resample(self.original_sample_rate)
            signal.truncate_samples(self.original_length)
        return signals
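

# Example usage of DeepClustering (a sketch; the checkpoint path and source
# count are hypothetical, and the model at `model_path` must have been
# trained for deep clustering):
#
#     mix = nussl.AudioSignal('mixture.wav')
#     separator = DeepClustering(
#         mix, model_path='checkpoints/dpcl.pth',
#         num_sources=2, use_cuda=False)
#     separator.run()
#     estimates = separator.make_audio_signals()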