import io
import os

import numpy as np
import matplotlib.pyplot as plt
import soundfile as sf
from scipy.io import wavfile


# GTZAN genre names in class-index order: position in this tuple == label.
_GTZAN_GENRES = ('blues', 'classical', 'country', 'disco', 'hiphop',
                 'jazz', 'metal', 'pop', 'reggae', 'rock')


def load_path2gt(paths_file, config):
    """ Given the path, construct the ground truth vectors.

    Reads one audio path per line from `paths_file` and derives each path's
    label with path2gt_datasets(.), where the relation between the path and
    the ground truth is defined.

    Args:
        paths_file: text file listing one audio path per line. Blank lines
            are ignored; trailing '\\n' / '\\r\\n' are stripped.
        config: dict providing at least 'dataset' (forwarded to
            path2gt_datasets) and 'num_classes_dataset' (length of the
            one-hot vectors).

    Returns:
        (paths, path2gt, path2onehot): the paths in file order, a dict
        mapping each path to its integer label, and a dict mapping each
        path to its one-hot numpy vector.

    Raises:
        ValueError: if no ground truth can be determined for a path
            (raised by label2onehot).
    """
    paths = []
    path2gt = {}
    path2onehot = {}  # TODO: remove if not used by callers.
    # `with` closes the file even if labelling a path raises.
    with open(paths_file) as pf:
        for line in pf:
            # Strip '\r' as well so CRLF-terminated lists give clean paths.
            path = line.rstrip('\r\n')
            if not path:
                # Blank (e.g. trailing) lines are not paths.
                continue
            paths.append(path)
            label = path2gt_datasets(path, config['dataset'])
            path2gt[path] = label
            path2onehot[path] = label2onehot(label,
                                             config['num_classes_dataset'])
    return paths, path2gt, path2onehot


def label2onehot(label, num_classes):
    """ Convert class label to one hot vector.

    Example:
    >>> label2onehot(label=2, num_classes=5)
    array([0., 0., 1., 0., 0.])

    Raises:
        ValueError: if `label` is None (e.g. path2gt_datasets found no
            ground truth). Indexing a numpy array with None adds an axis
            instead of selecting an element, so the assignment below would
            otherwise set every entry to 1 and silently return all ones.
    """
    if label is None:
        raise ValueError('Cannot build a one-hot vector: label is None '
                         '(no ground truth was found for this path).')
    onehot = np.zeros(num_classes)
    onehot[label] = 1
    return onehot


def path2gt_datasets(path, dataset):
    """ Given the audio path, it returns the ground truth label.

    Define HERE a new dataset to employ this code with other data.

    For 'GTZAN' the label is the index (in _GTZAN_GENRES order) of the
    first genre name found as a substring of the file name; if the file
    name contains no genre name, the whole path is searched the same way.

    Prints a message and returns None when no label is found or when the
    dataset is not implemented.
    """
    if dataset == 'GTZAN':
        # Search the file name before the full path: a parent directory that
        # happens to contain a genre word (e.g. '/data/popular/...') would
        # otherwise label every file below it. GTZAN files are presumably
        # named '<genre>.<nnnnn>.wav'; the full-path fallback keeps layouts
        # where only the directory carries the genre working.
        for candidate in (os.path.basename(path), path):
            for label, genre in enumerate(_GTZAN_GENRES):
                if genre in candidate:
                    return label
        print('Did not find the corresponding ground truth (' + str(path) + ')!')
    else:
        print('Did not find the implementation of ' + str(dataset) + ' dataset!')
    return None


def matrix_visualization(matrix, title=None):
    """ Visualize 2D matrices like spectrograms or feature maps.

    The matrix is transposed and flipped vertically before plotting, so its
    first axis runs left-to-right and its second axis bottom-to-top (e.g.
    time on x and frequency on y if `matrix` is laid out as
    (time, frequency) — assumed; confirm with callers). Opens a new figure,
    adds a colorbar and an optional title, then calls plt.show().
    """
    plt.figure()
    # interpolation=None defers to rcParams['image.interpolation'].
    plt.imshow(np.flipud(matrix.T), interpolation=None)
    plt.colorbar()
    if title is not None:
        plt.title(title)
    plt.show()


def wavefile_to_waveform(wav_file, features_type):
    """ Load an audio file as a waveform of at least one second.

    Args:
        wav_file: anything soundfile.read accepts (typically a path).
        features_type: when 'vggish', the audio is quantized to 16-bit PCM
            and rescaled by 1/32768, mirroring how the VGGish AudioSet code
            reads wav files with scipy.io.wavfile.

    Returns:
        (data, sr): the waveform as a float numpy array whose first axis is
        samples, and the sample rate. Clips shorter than `sr` samples are
        repeat-padded to exactly `sr` samples (one second); longer clips are
        returned unchanged.

    Raises:
        ValueError: if the file contains no samples (nothing to repeat).
    """
    data, sr = sf.read(wav_file)
    if features_type == 'vggish':
        # Round-trip through an in-memory 16-bit WAV so samples are read as in
        # VGGish's AudioSet pipeline (sr, wav_data = wavfile.read(wav_file)).
        # An in-memory buffer avoids writing a temporary file into the working
        # directory (name collisions, leftovers if reading raises).
        buffer = io.BytesIO()
        sf.write(buffer, data, sr, subtype='PCM_16', format='WAV')
        buffer.seek(0)
        sr, wav_data = wavfile.read(buffer)
        assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
        data = wav_data / 32768.0  # Convert to [-1.0, +1.0]

    # Ensure at least one second of samples; if shorter, repeat-pad.
    num_samples = data.shape[0]
    if num_samples < sr:
        if num_samples == 0:
            # Repeating an empty array never reaches `sr` samples.
            raise ValueError('Cannot repeat-pad an empty waveform: ' + str(wav_file))
        repeats = -(-sr // num_samples)  # ceil(sr / num_samples)
        data = np.concatenate([data] * repeats, axis=0)[:sr]
    return data, sr