# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Simple speech recognition to spot a limited number of keywords."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys

import numpy as np
import tensorflow as tf
from six.moves import xrange  # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import gfile

from speech import input_data
from speech import models

FLAGS = None


class AudioProcessor(object):
  """Handles loading, partitioning, and preparing audio training data."""

  def __init__(self, data_dir, model_settings, feature_scaling=''):
    self.data_dir = data_dir
    self.feature_scaling = feature_scaling
    self.prepare_data_index()
    self.prepare_processing_graph(model_settings)

  def prepare_data_index(self):
    """Looks through all the subfolders to find audio samples."""
    search_path = os.path.join(self.data_dir, '*', '*.wav')
    self.data_indexs = []
    for wav_path in gfile.Glob(search_path):
      self.data_indexs.append(wav_path)

  def prepare_processing_graph(self, model_settings):
    """Builds a TensorFlow graph that turns a WAV file into a fingerprint.

    Creates a graph that loads a WAVE file, decodes it, scales the volume,
    shifts it in time, calculates a log mel spectrogram, and then optionally
    builds an MFCC fingerprint from that.

    This must be called with an active TensorFlow session running, and it
    creates multiple placeholder inputs, and one output:

      - wav_filename_placeholder_: Filename of the WAV to load.
      - foreground_volume_placeholder_: How loud the main clip should be.
      - time_shift_padding_placeholder_: Where to pad the clip.
      - time_shift_offset_placeholder_: How much to move the clip in time.
      - mfcc_: Output 2D fingerprint of processed audio.

    Args:
      model_settings: Information about the current model being trained.
    """
    desired_samples = model_settings['desired_samples']
    self.wav_filename_placeholder_ = tf.placeholder(tf.string, [])
    wav_loader = io_ops.read_file(self.wav_filename_placeholder_)
    wav_decoder = contrib_audio.decode_wav(
        wav_loader, desired_channels=1, desired_samples=desired_samples)
    # Allow the audio sample's volume to be adjusted.
    self.foreground_volume_placeholder_ = tf.placeholder(tf.float32, [])
    scaled_foreground = tf.multiply(wav_decoder.audio,
                                    self.foreground_volume_placeholder_)
    # Shift the sample's start position, and pad any gaps with zeros.
    self.time_shift_padding_placeholder_ = tf.placeholder(tf.int32, [2, 2])
    self.time_shift_offset_placeholder_ = tf.placeholder(tf.int32, [2])
    padded_foreground = tf.pad(
        scaled_foreground,
        self.time_shift_padding_placeholder_,
        mode='CONSTANT')
    sliced_foreground = tf.slice(padded_foreground,
                                 self.time_shift_offset_placeholder_,
                                 [desired_samples, -1])
    # Weight matrix mapping the 1025 (= 2048 / 2 + 1) FFT bins onto mel bins.
    mel_bias_ = tf.contrib.signal.linear_to_mel_weight_matrix(
        num_mel_bins=model_settings['dct_coefficient_count'],
        num_spectrogram_bins=int(2048 / 2 + 1),
        sample_rate=model_settings['sample_rate'],
        lower_edge_hertz=125,
        upper_edge_hertz=float(model_settings['sample_rate'] / 2 - 200))
    spectrogram = tf.abs(
        tf.contrib.signal.stft(
            tf.transpose(sliced_foreground),
            model_settings['window_size_samples'],
            model_settings['window_stride_samples'],
            fft_length=2048,
            window_fn=tf.contrib.signal.hann_window,
            pad_end=False))
    # Convert the power spectrogram to log mel filterbank energies.
    S = tf.matmul(tf.reshape(tf.pow(spectrogram, 2), [-1, 1025]), mel_bias_)
    log_mel_spectrograms = tf.log(tf.maximum(S, 1e-7))
    if model_settings['feature_type'] == 'fbank':
      self.mfcc_ = log_mel_spectrograms
    elif model_settings['feature_type'] == 'mfcc':
      # Compute MFCCs from log_mel_spectrograms.
      self.mfcc_ = tf.contrib.signal.mfccs_from_log_mel_spectrograms(
          log_mel_spectrograms)
    else:
      raise ValueError(
          'Unsupported feature_type: {}'.format(model_settings['feature_type']))

  def set_size(self):
    """Calculates the number of samples in the dataset.

    Returns:
      Number of samples in the dataset.
    """
    return len(self.data_indexs)

  def get_data(self, how_many, offset, model_settings, sess):
    """Gathers samples from the data set, applying transformations as needed.

    Args:
      how_many: Desired number of samples, or -1 for the whole set.
      offset: Index of the first sample to return.
      model_settings: Information about the current model being trained.
      sess: TensorFlow session used to run the processing graph.

    Returns:
      List of sample data for the transformed samples, and WAV file names.
    """
    candidates = self.data_indexs
    if how_many == -1:
      sample_count = len(candidates)
    else:
      sample_count = max(0, min(how_many, len(candidates) - offset))
    # Data and file names will be populated and returned.
    data = np.zeros((sample_count, model_settings['fingerprint_size']))
    wav_files = []
    # Use the processing graph we created earlier to repeatedly generate the
    # final output sample data we'll use for prediction.
    for i in xrange(offset, offset + sample_count):
      # Pick which audio sample to use.
      sample_file = candidates[i]
      # No time shifting is applied at prediction time.
      time_shift_amount = 0
      if time_shift_amount > 0:
        time_shift_padding = [[time_shift_amount, 0], [0, 0]]
        time_shift_offset = [0, 0]
      else:
        time_shift_padding = [[0, -time_shift_amount], [0, 0]]
        time_shift_offset = [-time_shift_amount, 0]
      input_dict = {
          self.wav_filename_placeholder_: sample_file,
          self.time_shift_padding_placeholder_: time_shift_padding,
          self.time_shift_offset_placeholder_: time_shift_offset,
          self.foreground_volume_placeholder_: 1,
      }
      # Run the graph to produce the output audio.
      data[i - offset, :] = sess.run(
          self.mfcc_, feed_dict=input_dict).flatten()
      wav_files.append(sample_file)
    return input_data.AudioProcessor.apply_feature_scaling(
        data, self.feature_scaling,
        model_settings['dct_coefficient_count']), wav_files


def main(_):
  # We want to see all the logging messages.
  tf.logging.set_verbosity(tf.logging.INFO)
  sess = tf.InteractiveSession()
  model_settings = models.prepare_model_settings(
      len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
      FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
      FLAGS.window_stride_ms, FLAGS.dct_coefficient_count, FLAGS.feature_type)
  audio_processor = AudioProcessor(FLAGS.data_dir, model_settings,
                                   feature_scaling=FLAGS.feature_scaling)
  fingerprint_size = model_settings['fingerprint_size']
  fingerprint_input = tf.placeholder(
      tf.float32, [None, fingerprint_size], name='fingerprint_input')
  logits = models.create_model(
      fingerprint_input,
      model_settings,
      FLAGS.model_architecture,
      hparam_string=FLAGS.hparams,
      is_training=False)
  softmax = tf.nn.softmax(logits, name='labels_softmax')
  tf.global_variables_initializer().run()
  checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir)
  if checkpoint_path:
    models.load_variables_from_checkpoint(sess, checkpoint_path)
  else:
    raise ValueError('No checkpoint found in {}'.format(FLAGS.train_dir))
  set_size = audio_processor.set_size()
  tf.logging.info('set_size=%d', set_size)
  with gfile.GFile(FLAGS.output_csv, 'w') as wf:
    # The CSV header lists one probability column per output label.
    wf.write('fname,{}\n'.format(','.join(
        input_data.prepare_words_list(FLAGS.wanted_words.split(',')))))
    for i in xrange(0, set_size, FLAGS.batch_size):
      test_fingerprints, test_wavfiles = audio_processor.get_data(
          FLAGS.batch_size, i, model_settings, sess)
      probs = sess.run(softmax,
                       feed_dict={fingerprint_input: test_fingerprints})
      for k, wav_file in enumerate(test_wavfiles):
        wf.write('%s,%s\n' % (wav_file.split('/')[-1],
                              ','.join([str(v) for v in probs[k]])))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/speech_dataset/',
      help='Where to find the speech data to run prediction on.')
  parser.add_argument(
      '--time_shift_ms',
      type=float,
      default=100.0,
      help='Range to randomly shift the training audio by in time.')
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs.')
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs.')
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is.')
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How far to move in time between spectrogram timeslices.')
  parser.add_argument(
      '--dct_coefficient_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint.')
  parser.add_argument(
      '--batch_size',
      type=int,
      default=128,
      help='How many items to run through the model at once.')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label).')
  parser.add_argument(
      '--train_dir',
      type=str,
      default='/tmp/speech_commands_train',
      help='Directory holding the training checkpoints to load.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='conv',
      help='What model architecture to use.')
  parser.add_argument(
      '--hparams',
      type=str,
      default='',
      help='Hyperparameter string.')
  parser.add_argument(
      '--output_csv',
      type=str,
      default='',
      help='Output CSV file name for the per-label probabilities.')
  parser.add_argument(
      '--feature_scaling',
      type=str,
      default='',  # Either '' (no scaling) or 'cmvn'.
      help='Feature normalization to apply.')
  parser.add_argument(
      '--feature_type',
      type=str,
      default='mfcc',
      help='Feature type (e.g. mfcc or fbank).')

  FLAGS, unparsed = parser.parse_known_args()
  if not FLAGS.output_csv:
    raise ValueError('--output_csv must be set.')
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
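
# Example invocation (a sketch only; the script filename "predict.py" and the
# paths below are placeholders, not defined anywhere in this repository):
#
#   python predict.py \
#     --data_dir=/tmp/speech_dataset/ \
#     --train_dir=/tmp/speech_commands_train \
#     --model_architecture=conv \
#     --feature_type=mfcc \
#     --output_csv=/tmp/predictions.csv
#
# This writes one CSV row per WAV file found under --data_dir, containing the
# file name followed by the softmax probability for each label in the header.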