from runners.utils import load_yaml
from src import train, dataset, model, logging
from src.utils import loaders, seed
import glob
import pytest
import sys, os
import torch
import shutil
# Seed via the project's ``src.utils.seed`` helper (presumably all RNGs used
# by the code under test) so runs are reproducible.
seed(0)
# Directory that receives the log file written while these tests run.
_LOG_DIR = 'tests/out/_test_train/logs/'
logger = logging.getLogger()
os.makedirs(_LOG_DIR, exist_ok=True)
# Attach a file handler to the logger returned by ``getLogger()`` so output
# produced during the tests is kept on disk for later inspection.
fh = logging.FileHandler(os.path.join(_LOG_DIR, 'output.txt'))
logger.addHandler(fh)
def _load_dataset(config, split):
    """Load one dataset split from ``config`` with test-specific overrides.

    A copy of ``config['dataset_config']`` is made and these keys are forced:
    ``overwrite_cache=True``, ``cache='tests/out/_test_dataset/'`` and
    ``fraction_of_dataset=.1`` (presumably 10% of the data, to keep tests fast).
    The caller's ``config`` is left unmodified.

    Args:
        config: Experiment configuration loaded from YAML. Reads
            ``config['dataset_config']`` and the ``'class'`` and ``'folder'``
            entries of ``config['datasets'][split]``.
        split: Key into ``config['datasets']`` (e.g. ``'train'`` or ``'val'``).

    Returns:
        Whatever ``loaders.load_dataset`` returns for that split.

    Raises:
        KeyError: if any of the config entries above is missing.
    """
    # Copy rather than mutate: the same config dict object is shared by every
    # parametrized test in this module.
    dataset_config = dict(config['dataset_config'])
    dataset_config['overwrite_cache'] = True
    dataset_config['cache'] = 'tests/out/_test_dataset/'
    dataset_config['fraction_of_dataset'] = .1
    split_config = config['datasets'][split]
    return loaders.load_dataset(
        split_config['class'],
        split_config['folder'],
        dataset_config,
    )
# Each experiment YAML in ./experiments/ becomes one parametrized test case,
# with its path used as the test id. glob returns paths in arbitrary order,
# so sort them: pytest (and especially pytest-xdist, which requires identical
# collection order on every worker) needs a deterministic parameter order.
paths_to_yml = sorted(glob.glob('./experiments/*.yml'))
configs = [load_yaml(path_to_yml) for path_to_yml in paths_to_yml]
@pytest.mark.parametrize("config", configs, ids=paths_to_yml)
def test_dataset(config):
    """Smoke-test that every dataset split in ``config`` loads and is indexable.

    Configs without a ``'datasets'`` section pass trivially, the same way the
    other tests here skip configs that lack the section they need.
    """
    for split in config.get('datasets', {}):
        dset = _load_dataset(config, split)
        # Fetching one item exercises the dataset's item-loading path.
        dset[0]
 
@pytest.mark.parametrize("config", configs, ids=paths_to_yml)
def test_model(config):
    """Smoke-test that the model described by ``config['model_config']`` builds.

    The test passes if ``loaders.load_model`` raises nothing. Configs without
    a ``'model_config'`` entry pass trivially.
    """
    if 'model_config' in config:
        # Result intentionally unbound: binding it to ``model`` would shadow
        # the module-level ``model`` import, and it is not used afterwards.
        loaders.load_model(config['model_config'])
 
@pytest.mark.parametrize("config", configs, ids=paths_to_yml)
def test_model_and_dataset_match(config):
    """Check that the model can run a forward pass on one item of every split.

    For each split in ``config['datasets']``, the first item is converted
    value-by-value with ``torch.from_numpy(...).unsqueeze(0).float()``, which
    adds a leading dimension of size 1. The values are moved to CUDA when
    available (CPU otherwise) and passed to the model. The test passes if the
    forward call raises nothing. Configs lacking either ``'datasets'`` or
    ``'model_config'`` pass trivially.
    """
    if 'datasets' not in config or 'model_config' not in config:
        return
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The model does not depend on the split, so build it once.
    model_instance = loaders.load_model(config['model_config']).to(device)
    for split in config['datasets']:
        data = _load_dataset(config, split)[0]
        # assumes each item is a mutable mapping of NumPy arrays — TODO confirm
        for key in data:
            data[key] = torch.from_numpy(
                data[key]
            ).unsqueeze(0).float().to(device)
        model_instance(data)
 
@pytest.mark.parametrize("config", configs, ids=paths_to_yml)
def test_train(config):
    """Run a one-epoch training smoke test for configs with ``'train_config'``.

    The trainer class is looked up by name (``train_config['class']``) on the
    ``train`` module. It is fit on the ``'train'`` split, using the ``'val'``
    split as validation data, with output written under
    ``tests/out/_test_train/``. Configs without ``'train_config'`` pass
    trivially.
    """
    if 'train_config' not in config:
        return
    # Work on a copy: popping 'class' from the shared parametrized config
    # would make any re-run in the same session fail with KeyError, and
    # forcing num_epochs should not leak into other tests either.
    train_config = dict(config['train_config'])
    train_class = train_config.pop('class')
    output_folder = 'tests/out/_test_train/'
    train_config['num_epochs'] = 1
    TrainerClass = getattr(train, train_class)
    train_dataset = _load_dataset(config, 'train')
    val_dataset = _load_dataset(config, 'val')
    model_instance = loaders.load_model(config['model_config'])
    trainer = TrainerClass(
        output_folder,
        train_dataset,
        model_instance,
        train_config,
        validation_data=val_dataset,
        use_tensorboard=True,
        experiment=None,
    )
    trainer.fit()