import json

import numpy as np
import torch
from torch import nn

from . import modules  # assumed relative import; the docstring refers to networks.modules


class SeparationModel(nn.Module):
def __init__(self, config, extra_modules=None):
"""
SeparationModel takes a configuration file or dictionary that describes the model
structure, which is some combination of MelProjection, Embedding, RecurrentStack,
ConvolutionalStack, and other modules found in networks.modules. The configuration file
can be built using the helper functions in config.builders:
- build_dpcl_config: Builds the original deep clustering network, mapping each
time-frequency point to an embedding of some size. Takes as input a
log_spectrogram.
- build_mi_config: Builds a "traditional" mask inference network that maps the mixture
spectrogram to source estimates. Takes as input a log_spectrogram and a
magnitude_spectrogram.
- build_chimera_config: Builds a Chimera network with a mask inference head and a
deep clustering head to map. A combination of MI and DPCL. Takes as input a
log_spectrogram and a magnitude_spectrogram.
References:
Hershey, J. R., Chen, Z., Le Roux, J., & Watanabe, S. (2016, March).
Deep clustering: Discriminative embeddings for segmentation and separation.
In Acoustics, Speech and Signal Processing (ICASSP),
2016 IEEE International Conference on (pp. 31-35). IEEE.
Luo, Y., Chen, Z., Hershey, J. R., Le Roux, J., & Mesgarani, N. (2017, March).
Deep clustering and conventional networks for music separation: Stronger together.
In Acoustics, Speech and Signal Processing (ICASSP),
2017 IEEE International Conference on (pp. 61-65). IEEE.
        Args:
            config: (str, dict) Either a config dictionary built from one of the helper
                functions, the path to a json file containing such a config, or a json
                string of one.
            extra_modules: (optional) A namespace (e.g. an imported module) whose
                attributes are additional nn.Module classes. Anything not already in
                networks.modules is made available to the config under the same name.
        Examples:
            >>> args = {
            ...     'num_frequencies': 512,
            ...     'num_mels': 128,
            ...     'sample_rate': 44100,
            ...     'hidden_size': 300,
            ...     'bidirectional': True,
            ...     'num_layers': 4,
            ...     'embedding_size': 20,
            ...     'num_sources': 4,
            ...     'embedding_activation': ['sigmoid', 'unitnorm'],
            ...     'mask_activation': ['softmax']
            ... }
            >>> config = helpers.build_chimera_config(args)
            >>> with open('config.json', 'w') as f:
            ...     json.dump(config, f)
            >>> model = SeparationModel('config.json')
            >>> test_data = np.random.random((1, 100, 512))
            >>> data = torch.from_numpy(test_data).float()
            >>> output = model({'log_spectrogram': data,
            ...                 'magnitude_spectrogram': data})
"""
super(SeparationModel, self).__init__()
        if isinstance(config, str):
            # Treat the string as a path if it names a JSON file; otherwise
            # assume it is a JSON string and parse it directly.
            if config.endswith('.json'):
                with open(config, 'r') as f:
                    config = json.load(f)
            else:
                config = json.loads(config)
        # Expose anything defined in extra_modules that is not already in
        # modules, so custom classes can be referenced by name in the config.
if extra_modules:
for name in dir(extra_modules):
if name not in dir(modules):
setattr(
modules,
name,
getattr(extra_modules, name)
)
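        # Hedged usage sketch: extra_modules can be any namespace (such as an
        # imported module) whose attributes are nn.Module subclasses, e.g.
        #
        #   import my_layers  # hypothetical module defining MyCustomStack
        #   model = SeparationModel('config.json', extra_modules=my_layers)
        #
        # after which a config entry like {'class': 'MyCustomStack', 'args': {...}}
        # resolves against my_layers.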
module_dict = {}
self.input = {}
for module_key in config['modules']:
module = config['modules'][module_key]
            if 'input_shape' not in module:
                # Look up the layer class by name in modules and instantiate it.
                class_func = getattr(modules, module['class'])
                module_dict[module_key] = class_func(**module['args'])
            else:
                # Entries with an input_shape are not layers; they declare the
                # inputs the model expects in its forward pass.
                self.input[module_key] = module['input_shape']
self.layers = nn.ModuleDict(module_dict)
self.connections = config['connections']
self.output_keys = config['output']
self.config = config
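
    # A hedged sketch of the config structure this class consumes, inferred
    # from __init__ above (module names and shapes here are illustrative):
    #
    #   {
    #       'modules': {
    #           'log_spectrogram': {'input_shape': [-1, -1, 512]},
    #           'recurrent_stack': {'class': 'RecurrentStack', 'args': {...}},
    #           'embedding': {'class': 'Embedding', 'args': {...}}
    #       },
    #       'connections': [
    #           ['recurrent_stack', ['log_spectrogram']],
    #           ['embedding', ['recurrent_stack']]
    #       ],
    #       'output': ['embedding']
    #   }
    #
    # Each connection is [layer_key, [input keys]]; entries with an
    # 'input_shape' are the raw inputs forward() expects in its data dict.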
    def forward(self, data):
"""
Args:
data: (dict) a dictionary containing the input data for the model. Should match the input_keys
in self.input.
Returns:
"""
if not all(name in list(data) for name in list(self.input)):
raise ValueError(f'Not all keys present in data! Needs {", ".join(self.input)}')
output = {}
for connection in self.connections:
layer = self.layers[connection[0]]
input_data = []
for c in connection[1]:
input_data.append(output[c] if c in output else data[c])
_output = layer(*input_data)
            if isinstance(_output, dict):
                # Layers that return dictionaries have each value stored under
                # 'layer_name:key' so later connections and output_keys can
                # reference them individually.
                for k in _output:
                    output[f'{connection[0]}:{k}'] = _output[k]
else:
output[connection[0]] = _output
return {o: output[o] for o in self.output_keys}
    def project_data(self, data, clamp=False):
        """
        Projects data through the model's 'mel_projection' layer if one exists,
        optionally clamping the result to [0.0, 1.0]. If the model has no
        'mel_projection' layer, data is returned unchanged.
        """
        if 'mel_projection' in self.layers:
data = self.layers['mel_projection'](data)
if clamp:
data = data.clamp(0.0, 1.0)
return data
    def save(self, location, metadata=None):
"""
Saves a SeparationModel into a location into a dictionary with the
weights and model configuration.
Args:
location: (str) Where you want the model saved, as a path.
Returns:
"""
save_dict = {
'state_dict': self.state_dict(),
'config': json.dumps(self.config)
}
save_dict = {**save_dict, **(metadata if metadata else {})}
torch.save(save_dict, location)
return location
    def __repr__(self):
        output = super().__repr__()
        # Count only trainable parameters.
        num_parameters = 0
        for p in self.parameters():
            if p.requires_grad:
                num_parameters += p.numel()
        output += '\nNumber of parameters: %d' % num_parameters
        return output