from .. import dataset, model
[docs]def load_dataset(dataset_class, dataset_folder, dataset_config):
"""
This is a helper function that looks in the
:py:mod:`src.dataset` module.
Args:
dataset_class (str): Name of the dataset class you want to
instantiate (e.g. Scaper, MixSourceFolder).
dataset_folder (str): Folder you want to load the data from.
dataset_config (dict): Configuration of the dataset
Returns:
:py:class:`torch.utils.data.Dataset`: Instantiated DatasetClass given the parameters.
"""
DatasetClass = getattr(dataset, dataset_class)
dataset_instance = DatasetClass(dataset_folder, dataset_config)
return dataset_instance
[docs]def load_model(model_config):
"""
Loads a deep :py:class:`SeparationModel` given a model configuration.
Args:
model_config (dict): Model configuration with a 'class' key. The rest of the keys
get put into the 'args'.
Returns:
:py:class:`SeparationModel`: Instantiated deep model given the parameters.
"""
model_class = model_config.pop('class', 'SeparationModel')
ModelClass = getattr(model, model_class)
if model_class == 'SeparationModel':
model_instance = ModelClass(model_config, extra_modules=model.extras)
else:
model_instance = ModelClass(model_config)
return model_instance