Source code for scripts.visualize

from src.utils import loaders
from src import algorithms, model
from . import cmd, document_parser
from runners.experiment_utils import load_experiment
import inspect
import numpy as np
import os
import matplotlib.pyplot as plt
from argparse import ArgumentParser
import logging
import torch

def _visualize_file(AlgorithmClass, algorithm_args, dataset, file_name,
                    output_folder):
    """
    Runs the algorithm on a single file and writes the results to disk.

    Writes ``viz.png`` (when plotting succeeds), ``mixture.wav`` and one
    ``source{i}.wav`` per estimate into ``output_folder``.

    Args:
        AlgorithmClass: Class to instantiate as
            ``AlgorithmClass(mixture, **algorithm_args)``.
        algorithm_args (dict): Keyword arguments for ``AlgorithmClass``.
        dataset: Dataset object; only ``load_audio_files(file_name)`` is used.
        file_name (str): File to load from the dataset.
        output_folder (str): Folder to write into. Created if missing.
    """
    os.makedirs(output_folder, exist_ok=True)

    mixture = dataset.load_audio_files(file_name)[0]
    logging.info(mixture)

    _algorithm = AlgorithmClass(mixture, **algorithm_args)
    _algorithm.run()
    estimates = _algorithm.make_audio_signals()

    # Plotting is best-effort: a failure here must not prevent the audio
    # from being written below.
    figure = None
    try:
        figure = plt.figure(figsize=(20, 10))
        _algorithm.plot()
        plt.savefig(
            os.path.join(output_folder, 'viz.png'),
            bbox_inches='tight',
            dpi=100,
        )
    except Exception:
        logging.exception('Unable to plot.')
    finally:
        # Close the figure so figures do not pile up across files.
        if figure is not None:
            plt.close(figure)

    mixture.write_audio_to_file(os.path.join(output_folder, 'mixture.wav'))
    for index, estimate in enumerate(estimates):
        estimate.write_audio_to_file(
            os.path.join(output_folder, f'source{index}.wav')
        )


def visualize(path_to_yml_file, file_names=None, eval_keys=None):
    """
    Takes in a path to a yml file containing an experiment configuration and
    runs the algorithm specified in the experiment on files from the datasets
    specified in the experiment. If no file names are given, one random file
    is drawn from each dataset. If the algorithm has plotting available, then
    plot is used to visualize the algorithm and save it to a figure. The
    associated audio is also saved.

    Output is written to
    ``<config['info']['output_folder']>/viz/<eval_key>/<file base name>/``.

    Failures on an individual file are logged (with traceback) and the
    remaining files are still processed.

    Args:
        path_to_yml_file (str): Path to the yml file that defines the
            experiment.
        file_names (list): Files to visualize, used for every eval key. If
            empty or None, a random file is drawn independently from each
            dataset. Defaults to None.
        eval_keys (list): All of the dataset keys to be used to visualize the
            experiment. Will visualize for each eval_key in sequence. Keys
            that are not in the experiment's datasets are skipped. Defaults
            to ['test'].
    """
    if eval_keys is None:
        eval_keys = ['test']

    config, _, _ = load_experiment(path_to_yml_file)

    algorithm_config = config['algorithm_config']
    AlgorithmClass = getattr(algorithms, algorithm_config['class'])

    # Only inject the optional arguments that this algorithm's constructor
    # actually accepts.
    args = inspect.getfullargspec(AlgorithmClass)[0]
    if 'extra_modules' in args:
        algorithm_config['args']['extra_modules'] = model.extras
    if 'use_cuda' in args:
        algorithm_config['args']['use_cuda'] = torch.cuda.is_available()

    _datasets = {}
    for key in eval_keys:
        if key in config['datasets']:
            _datasets[key] = loaders.load_dataset(
                config['datasets'][key]['class'],
                config['datasets'][key]['folder'],
                config['dataset_config']
            )

    for key, dataset in _datasets.items():
        # Choose the files per key without rebinding ``file_names``, so a
        # random pick for one dataset is never reused for the next one.
        if file_names:
            key_file_names = file_names
        elif len(dataset) > 0:
            random_index = np.random.randint(len(dataset))
            key_file_names = [dataset.files[random_index]]
        else:
            logging.warning(f'Dataset {key} is empty, nothing to visualize.')
            continue

        for file_name in key_file_names:
            try:
                logging.info(f'Visualizing {file_name}')
                folder = os.path.splitext(os.path.basename(file_name))[0]
                output_folder = os.path.join(
                    config['info']['output_folder'], 'viz', key, folder)
                _visualize_file(
                    AlgorithmClass,
                    algorithm_config['args'],
                    dataset,
                    file_name,
                    output_folder,
                )
            except Exception:
                # Best-effort per file: log the real cause and move on.
                # KeyboardInterrupt/SystemExit are deliberately not caught.
                logging.exception(f'Unable to visualize {file_name}.')
# Builds the command line parser for the visualize script. The options mirror
# the arguments of ``visualize``: -p/--path_to_yml_file (required),
# -f/--file_names (one or more) and -e/--eval_keys (one or more, defaults to
# ['test']).
# NOTE(review): presumably ``document_parser`` uses this parser to generate
# documentation; no docstring is set here so ``__doc__`` stays as it was.
@document_parser('visualize', 'scripts.visualize.visualize')
def build_parser():
    yml_help = (
        'Path to the yml file that defines the experiment. The visualization '
        'will be placed into a "viz" folder in the same directory as the yml '
        'file. '
    )
    files_help = (
        "Files to evaluate. Use only the base name of each file in the list "
        "that is being evaluated. "
    )
    keys_help = (
        "All of the dataset keys to be used to visualize the experiment. "
        "Will visualize for each eval_key in sequence. Defaults to "
        "['test']. "
    )

    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "-p", "--path_to_yml_file",
        type=str,
        required=True,
        help=yml_help,
    )
    arg_parser.add_argument(
        "-f", "--file_names",
        nargs='+',
        type=str,
        help=files_help,
    )
    arg_parser.add_argument(
        "-e", "--eval_keys",
        nargs='+',
        type=str,
        default=['test'],
        help=keys_help,
    )
    return arg_parser
# Script entry point: hand the visualize function and its parser factory to
# the shared ``cmd`` helper when this module is run directly.
if __name__ == "__main__":
    cmd(visualize, build_parser)