import re
import json
import pickle
import numpy as np
from pathlib import Path
from scipy.stats.mstats import gmean

from src.datasets import get_noisy_data_generator, get_folds_data, get_augment_folds_data_generator
from src import config


def gmean_preds_blend(probs_df_lst):
    """Blend several prediction DataFrames with an element-wise geometric mean.

    Every frame is re-indexed to the rows and columns of the first frame
    before stacking, so inputs may list their rows in a different order.

    Args:
        probs_df_lst: non-empty sequence of DataFrames with predictions for
            the same index labels and the same columns.

    Returns:
        A new DataFrame with the index and columns of ``probs_df_lst[0]``
        holding the blended values. The input frames are left untouched.

    Raises:
        ValueError: if ``probs_df_lst`` is empty.
    """
    if not probs_df_lst:
        raise ValueError("probs_df_lst must contain at least one DataFrame")
    reference_df = probs_df_lst[0]
    # Align rows *and* columns to the reference frame so that element (i, j)
    # refers to the same sample / class in every stacked array.
    stacked = np.stack(
        [df.loc[reference_df.index, reference_df.columns].values for df in probs_df_lst],
        axis=0,
    )
    blended = gmean(stacked, axis=0)
    # Build a fresh frame rather than writing through ``.values``: that
    # mutated the caller's first DataFrame and is not guaranteed to write
    # back at all (``.values`` can be a copy, or read-only under pandas
    # copy-on-write). ``np.asarray`` unwraps the masked array mstats returns.
    return type(reference_df)(
        np.asarray(blended),
        index=reference_df.index,
        columns=reference_df.columns,
    )


# Checkpoint file names are expected to end in ``-<score>.pth``,
# e.g. ``model-0.9512.pth``.
_MODEL_SCORE_RE = re.compile(r'-(\d+(?:\.\d+)?)\.pth$')


def get_best_model_path(dir_path: Path, return_score=False):
    """Return the ``*.pth`` checkpoint in ``dir_path`` with the highest score.

    The score is parsed from the file name suffix ``-<score>.pth``; files
    without such a suffix are ignored.

    Args:
        dir_path: directory containing checkpoints (``Path`` or ``str``).
        return_score: when True, also return the parsed score.

    Returns:
        ``best_model_path`` or, if ``return_score``, ``(best_model_path, best_score)``.

    Raises:
        FileNotFoundError: if no checkpoint with a parsable score exists.
    """
    dir_path = Path(dir_path)
    model_scores = []
    for model_path in dir_path.glob('*.pth'):
        # Match only the file name, so directory names cannot produce a score.
        match = _MODEL_SCORE_RE.search(model_path.name)
        if match is not None:
            model_scores.append((model_path, float(match.group(1))))

    if not model_scores:
        raise FileNotFoundError(f"No '*-<score>.pth' checkpoints found in {dir_path}")

    best_model_path, best_score = max(model_scores, key=lambda item: item[1])
    if return_score:
        return best_model_path, best_score
    return best_model_path


def pickle_save(obj, filename):
    """Pickle ``obj`` to ``filename`` atomically.

    The data is written to ``<filename>.tmp`` first and then renamed over the
    destination, so an interrupted dump never leaves a truncated cache file
    that a later run would try to load.
    """
    print(f"Pickle save to: {filename}")
    path = Path(filename)
    tmp_path = path.with_name(path.name + '.tmp')
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
    except BaseException:
        # Do not leave a (possibly huge) half-written temp file behind.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    tmp_path.replace(path)


def pickle_load(filename):
    """Load and return a pickled object from ``filename``.

    NOTE: unpickling executes arbitrary code; only use on cache files this
    project wrote itself, never on untrusted input.
    """
    print(f"Pickle load from: {filename}")
    with open(filename, 'rb') as f:
        return pickle.load(f)


def _load_json(path):
    """Read a JSON file (UTF-8, per RFC 8259) and return the parsed object."""
    with open(path, encoding='utf-8') as file:
        return json.load(file)


def load_folds_data(use_corrections=True):
    """Load the folds dataset from the pickle cache, building it on a miss.

    The cache file name is derived from ``config.audio.get_hash``. When
    ``use_corrections`` is True, the corrections read from
    ``config.corrections_json_path`` are included in the hash and passed to
    ``get_folds_data``.
    """
    if use_corrections:
        corrections = _load_json(config.corrections_json_path)
        print("Corrections:", corrections)
        pkl_name = f'{config.audio.get_hash(corrections=corrections)}.pkl'
    else:
        corrections = None
        pkl_name = f'{config.audio.get_hash()}.pkl'

    folds_data_pkl_path = config.folds_data_pkl_dir / pkl_name
    if folds_data_pkl_path.exists():
        return pickle_load(folds_data_pkl_path)

    folds_data = get_folds_data(corrections)
    config.folds_data_pkl_dir.mkdir(parents=True, exist_ok=True)
    pickle_save(folds_data, folds_data_pkl_path)
    return folds_data


def _batch_sort_key(pkl_path):
    """Order ``<hash>_<index>.pkl`` files by numeric batch index, not lexically."""
    index = pkl_path.stem.rsplit('_', 1)[-1]
    # Lexical order puts ``_100`` between ``_10`` and ``_11``; sort numerically
    # and push any non-numeric stragglers to the end.
    return (int(index) if index.isdigit() else float('inf'), pkl_path.name)


def _load_or_generate_batches(pkl_dir, config_hash, make_generator, n_fields):
    """Load cached batches for ``config_hash`` or generate and cache them.

    Each batch is indexable; items ``0 .. n_fields - 1`` are concatenated
    across batches into ``n_fields`` result lists.

    Args:
        pkl_dir: cache directory (a ``Path``).
        config_hash: prefix of the cache files ``<config_hash>_<i>.pkl``.
        make_generator: zero-argument callable returning the batch generator;
            only called when no cached batches exist.
        n_fields: number of leading items of each batch to collect.

    Returns:
        A list of ``n_fields`` lists.
    """
    fields = [[] for _ in range(n_fields)]

    def _extend(data_batch):
        for k in range(n_fields):
            fields[k] += data_batch[k]

    pkl_paths = sorted(pkl_dir.glob(f'{config_hash}_*.pkl'), key=_batch_sort_key)
    if pkl_paths:
        for pkl_path in pkl_paths:
            _extend(pickle_load(pkl_path))
        return fields

    pkl_dir.mkdir(parents=True, exist_ok=True)
    staged = []
    for i, data_batch in enumerate(make_generator()):
        pkl_path = pkl_dir / f'{config_hash}_{i:02}.pkl'
        # Stage under a suffix the glob above does not match. Batches are
        # published only after the generator finishes, so an interrupted run
        # no longer leaves a partial set that later loads as if complete.
        staging_path = pkl_path.with_suffix('.partial')
        pickle_save(data_batch, staging_path)
        staged.append((staging_path, pkl_path))
        _extend(data_batch)

    for staging_path, pkl_path in staged:
        staging_path.replace(pkl_path)
    return fields


def load_noisy_data():
    """Load the noisy dataset as ``(images_lst, targets_lst)``, caching batches.

    The cache key includes the corrections from
    ``config.noisy_corrections_json_path``.
    NOTE(review): those corrections are not passed to
    ``get_noisy_data_generator``; presumably it reads them itself — confirm.
    """
    corrections = _load_json(config.noisy_corrections_json_path)
    config_hash = config.audio.get_hash(corrections=corrections)
    images_lst, targets_lst = _load_or_generate_batches(
        config.noisy_data_pkl_dir, config_hash, get_noisy_data_generator, n_fields=2)
    return images_lst, targets_lst


def load_augment_folds_data(time_stretch_lst, pitch_shift_lst):
    """Load augmented folds data as ``(images_lst, targets_lst, folds_lst)``.

    Batches are cached under a hash of the audio config plus the given
    time-stretch and pitch-shift lists.
    """
    config_hash = config.audio.get_hash(time_stretch_lst=time_stretch_lst,
                                        pitch_shift_lst=pitch_shift_lst)
    images_lst, targets_lst, folds_lst = _load_or_generate_batches(
        config.augment_folds_data_pkl_dir,
        config_hash,
        lambda: get_augment_folds_data_generator(time_stretch_lst, pitch_shift_lst),
        n_fields=3,
    )
    return images_lst, targets_lst, folds_lst