"""Input pipeline for DCASE 2018 Task 2 Baseline models."""

import functools
import os

import numpy as np
from scipy.io import wavfile
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as tf_audio

# All input clips use a 44.1 kHz sample rate.
SAMPLE_RATE = 44100


def clip_to_waveform(clip, clip_dir=None):
    """Decodes a WAV clip into a waveform tensor.

    Args:
      clip: string tensor holding the clip's file name.
      clip_dir: path to the directory containing the clip.

    Returns:
      A rank-1 waveform tensor. Evaluating it fails if the clip's sample rate
      is not SAMPLE_RATE or if the clip is not mono.
    """
    # Decode the WAV-format clip into a waveform tensor where
    # the values lie in [-1, +1].
    clip_path = tf.string_join([clip_dir, clip], separator=os.sep)
    clip_data = tf.read_file(clip_path)
    waveform, sr = tf_audio.decode_wav(clip_data)
    # Assert that the clip has the expected sample rate.
    check_sr = tf.assert_equal(sr, SAMPLE_RATE)
    # and that it is mono.
    check_channels = tf.assert_equal(tf.shape(waveform)[1], 1)
    with tf.control_dependencies([tf.group(check_sr, check_channels)]):
        # Drop only the channel axis. An unqualified squeeze would also remove
        # the time axis of a one-sample clip, producing a scalar that the STFT
        # downstream cannot consume.
        return tf.squeeze(waveform, axis=[1])


def clip_to_log_mel_examples(clip, clip_dir=None, hparams=None):
    """Decodes a WAV clip into a batch of log mel spectrum examples.

    Args:
      clip: string tensor holding the clip's file name.
      clip_dir: path to the directory containing the clip.
      hparams: hyperparameter object providing stft_window_seconds,
        stft_hop_seconds, mel_bands, mel_min_hz, mel_max_hz, mel_log_offset,
        example_window_seconds and example_hop_seconds.

    Returns:
      Tensor of log mel spectrogram examples, framed along the time axis
      (presumably [num_examples, frames_per_example, mel_bands] -- confirm
      against tf.contrib.signal.frame).
    """
    # Decode WAV clip into waveform tensor.
    waveform = clip_to_waveform(clip, clip_dir=clip_dir)

    # Convert waveform into spectrogram using a Short-Time Fourier Transform.
    # Note that tf.contrib.signal.stft() uses a periodic Hann window by default.
    window_length_samples = int(
        round(SAMPLE_RATE * hparams.stft_window_seconds))
    hop_length_samples = int(round(SAMPLE_RATE * hparams.stft_hop_seconds))
    # Smallest power of two that is >= the window length.
    fft_length = 2 ** int(
        np.ceil(np.log(window_length_samples) / np.log(2.0)))
    magnitude_spectrogram = tf.abs(tf.contrib.signal.stft(
        signals=waveform,
        frame_length=window_length_samples,
        frame_step=hop_length_samples,
        fft_length=fft_length))

    # Convert spectrogram into log mel spectrogram.
    num_spectrogram_bins = fft_length // 2 + 1
    linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
        num_mel_bins=hparams.mel_bands,
        num_spectrogram_bins=num_spectrogram_bins,
        sample_rate=SAMPLE_RATE,
        lower_edge_hertz=hparams.mel_min_hz,
        upper_edge_hertz=hparams.mel_max_hz)
    mel_spectrogram = tf.matmul(magnitude_spectrogram,
                                linear_to_mel_weight_matrix)
    # The offset keeps the log finite when a mel band has zero energy.
    log_mel_spectrogram = tf.log(mel_spectrogram + hparams.mel_log_offset)

    # Frame log mel spectrogram into examples.
    spectrogram_sr = 1 / hparams.stft_hop_seconds
    example_window_length_samples = int(
        round(spectrogram_sr * hparams.example_window_seconds))
    example_hop_length_samples = int(
        round(spectrogram_sr * hparams.example_hop_seconds))
    features = tf.contrib.signal.frame(
        signal=log_mel_spectrogram,
        frame_length=example_window_length_samples,
        frame_step=example_hop_length_samples,
        axis=0)

    return features


def record_to_labeled_log_mel_examples(csv_record, clip_dir=None, hparams=None,
                                       label_class_index_table=None,
                                       num_classes=None):
    """Creates a batch of log mel spectrum examples from a training record.

    Args:
      csv_record: a line from the train.csv file downloaded from Kaggle.
      clip_dir: path to a directory containing clips referenced by csv_record.
      hparams: tf.contrib.training.HParams object containing model
        hyperparameters.
      label_class_index_table: a lookup table that represents the class map.
      num_classes: number of classes in the class map.

    Returns:
      features: Tensor containing a batch of log mel spectrum examples.
      labels: Tensor containing corresponding labels in 1-hot format.
    """
    # The record has three columns; only the first two (clip file name and
    # label) are used here.
    [clip, label, _] = tf.decode_csv(
        csv_record, record_defaults=[[''], [''], [0]])

    features = clip_to_log_mel_examples(
        clip, clip_dir=clip_dir, hparams=hparams)

    class_index = label_class_index_table.lookup(label)
    label_onehot = tf.one_hot(class_index, num_classes)
    # Every example framed from the clip gets the clip's label.
    num_examples = tf.shape(features)[0]
    labels = tf.tile([label_onehot], [num_examples, 1])

    return features, labels


def get_class_map(class_map_path):
    """Constructs a class label lookup table from a class map.

    Args:
      class_map_path: path to a comma-delimited text file whose lines hold an
        integer class index in column 0 and the label string in column 1.

    Returns:
      label_class_index_table: lookup table from label string to class index;
        labels absent from the class map look up to -1.
      num_classes: number of lines in the class map file.
    """
    label_class_index_table = tf.contrib.lookup.HashTable(
        tf.contrib.lookup.TextFileInitializer(
            filename=class_map_path,
            key_dtype=tf.string, key_index=1,
            value_dtype=tf.int32, value_index=0,
            delimiter=','),
        default_value=-1)
    # Count lines with the file opened in a context manager so the handle is
    # closed deterministically instead of being left to the garbage collector.
    with open(class_map_path) as class_map_file:
        num_classes = sum(1 for _ in class_map_file)
    return label_class_index_table, num_classes


def train_input(train_csv_path=None, train_clip_dir=None, class_map_path=None,
                hparams=None):
    """Creates training input pipeline.

    Args:
      train_csv_path: path to the train.csv file provided by Kaggle.
      train_clip_dir: path to the unzipped audio_train/ directory from the
        audio_train.zip file provided by Kaggle.
      class_map_path: path to the class map prepared from the training data.
      hparams: tf.contrib.training.HParams object containing model
        hyperparameters.

    Returns:
      features: Tensor containing a batch of log mel spectrum examples.
      labels: Tensor containing corresponding labels in 1-hot format.
      num_classes: number of classes.
      iter_init: an initializer op for the iterator that provides features and
        labels, to be run before the input pipeline is read.
    """
    label_class_index_table, num_classes = get_class_map(class_map_path)

    dataset = tf.data.TextLineDataset(train_csv_path)
    # Skip the header.
    dataset = dataset.skip(1)
    # Shuffle the list of clips. 10K is big enough to cover all clips.
    dataset = dataset.shuffle(buffer_size=10000)

    # Map each clip to a batch of framed log mel spectrum examples.
    dataset = dataset.map(
        map_func=functools.partial(
            record_to_labeled_log_mel_examples,
            clip_dir=train_clip_dir,
            hparams=hparams,
            label_class_index_table=label_class_index_table,
            num_classes=num_classes),
        # 4 is empirically chosen to use 4 logical CPU cores. Adjust as
        # needed if more or less resources are available.
        num_parallel_calls=4)

    # Unbatch so that we have a dataset of individual examples that we can then
    # shuffle for training. 20K should be enough to allow shuffling across a
    # few hundred clips which are already in random order.
    dataset = dataset.apply(tf.contrib.data.unbatch())
    dataset = dataset.shuffle(buffer_size=20000)

    # Run until we have completed 100 epochs of the training set.
    dataset = dataset.repeat(100)

    # Batch examples.
    dataset = dataset.apply(
        tf.contrib.data.batch_and_drop_remainder(
            batch_size=hparams.batch_size))

    # Let the input pipeline run a few batches ahead so that the model is
    # never starved of data.
    dataset = dataset.prefetch(10)

    iterator = dataset.make_initializable_iterator()
    features, labels = iterator.get_next()

    return features, labels, num_classes, iterator.initializer