import os
import librosa
import subprocess
import tempfile
import io
import pysrt
from pysrt import SubRipTime
import string
import random
import chardet
import re
from datetime import timedelta
import numpy as np
# Import the submodule explicitly: logloss() calls sklearn.metrics.log_loss,
# and a bare "import sklearn" does not guarantee that sklearn.metrics is loaded
import sklearn.metrics

from .ffmpeg import Transcode
from .log import logger


class Media:
    """
    Media class represents a media file on disk for which the content can be
    analyzed and retrieved.
    """

    # List of supported media formats
    FORMATS = ['.mkv', '.mp4', '.wmv', '.avi', '.flv']

    # The frequency of the generated audio
    FREQ = 16000

    # The number of coefficients to extract from the mfcc
    N_MFCC = 13

    # The number of samples in each mfcc coefficient
    HOP_LEN = 512.0

    # The length (seconds) of each item in the mfcc analysis
    LEN_MFCC = HOP_LEN/FREQ

    def __init__(self, filepath, subtitles=None):
        prefix, ext = os.path.splitext(filepath)
        if ext == '.srt':
            # Locate the media file that belongs to this subtitle instead
            return self.from_srt(filepath)
        if not ext:
            raise ValueError('unknown file: "{}"'.format(filepath))
        if ext not in Media.FORMATS:
            raise ValueError('filetype {} not supported: "{}"'.format(ext, filepath))
        self.__subtitles = subtitles
        self.filepath = os.path.abspath(filepath)
        self.filename = os.path.basename(prefix)
        self.extension = ext
        self.offset = timedelta()

    def from_srt(self, filepath):
        prefix, ext = os.path.splitext(filepath)
        if ext != '.srt':
            raise ValueError('filetype must be .srt format')
        # Strip a two-character suffix such as ".en" from the subtitle name
        prefix = os.path.basename(re.sub(r'\.\w\w$', '', prefix))
        # Use the absolute path so that a bare filename does not produce an
        # empty directory name, which os.listdir rejects
        directory = os.path.dirname(os.path.abspath(filepath))
        for f in os.listdir(directory):
            _, ext = os.path.splitext(f)
            if f.startswith(prefix) and ext in Media.FORMATS:
                return self.__init__(os.path.join(directory, f), subtitles=[filepath])
        raise ValueError('no media for subtitle: "{}"'.format(filepath))

    def subtitles(self):
        if self.__subtitles is not None:
            for s in self.__subtitles:
                yield Subtitle(self, s)
        else:
            # Fall back to every .srt file next to the media file that
            # shares its filename prefix
            directory = os.path.dirname(self.filepath)
            for f in os.listdir(directory):
                if f.endswith('.srt') and f.startswith(self.filename):
                    yield Subtitle(self, os.path.join(directory, f))

    def mfcc(self, duration=60*15, seek=True):
        transcode = Transcode(self.filepath, duration=duration, seek=seek)
        self.offset = transcode.start
        print("Transcoding...")
        transcode.run()
        y, sr = librosa.load(transcode.output, sr=Media.FREQ)
        print("Analysing...")
        # Note: this assignment replaces the bound method with the feature
        # array, so mfcc() can be called only once per Media instance
        self.mfcc = librosa.feature.mfcc(
            y=y, sr=sr,
            hop_length=int(Media.HOP_LEN),
            n_mfcc=int(Media.N_MFCC)
        )
        os.remove(transcode.output)
        return self.mfcc


class Subtitle:
    """
    Subtitle class represents an .srt file on disk and provides
    functionality to inspect and manipulate the subtitle content
    """

    def __init__(self, media, path):
        self.media = media
        self.path = path
        self.subs = pysrt.open(self.path, encoding=self._find_encoding())

    def labels(self, subs=None):
        # Media.mfcc is a method until it has been called and only then holds
        # the feature array, so a comparison against None would never fire
        if not isinstance(self.media.mfcc, np.ndarray):
            raise RuntimeError("Must analyse mfcc before generating labels")
        samples = len(self.media.mfcc[0])
        labels = np.zeros(samples)
        for sub in self.subs if subs is None else subs:
            start = timeToPos(sub.start - self.offset())
            end = timeToPos(sub.end - self.offset())+1
            # Mark every frame covered by the subtitle, clipped to the array
            for i in range(start, end):
                if i >= 0 and i < len(labels):
                    labels[i] = 1
        return labels

    def _find_encoding(self):
        data = None
        with open(self.path, "rb") as f:
            data = f.read()
        det = chardet.detect(data)
        return det.get("encoding")

    def offset(self):
        # Convert the media offset (timedelta) to a SubRipTime
        d = self.media.offset
        hours, remainder = divmod(d.seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        return SubRipTime(
            hours=hours, minutes=minutes, seconds=seconds,
            milliseconds=d.microseconds // 1000
        )

    def logloss(self, pred, actual, margin=12):
        # Evaluate the log loss for every shift within +/- margin seconds
        blocks = secondsToBlocks(margin)
        logloss = np.ones(blocks*2)
        indices = np.ones(blocks*2)
        # Restrict the comparison to the labelled region plus the margin
        nonzero = np.nonzero(actual)[0]
        begin = max(nonzero[0]-blocks, 0)
        end = min(nonzero[-1]+blocks, len(actual)-1)
        pred = pred[begin:end]
        actual = actual[begin:end]
        for i, offset in enumerate(range(-blocks, blocks)):
            snippet = np.roll(actual, offset)
            try:
                logloss[i] = sklearn.metrics.log_loss(snippet[blocks:-blocks], pred[blocks:-blocks])
            except (ValueError, RuntimeWarning):
                pass
            indices[i] = offset
        return indices, logloss

    def sync(self, net, safe=True,
             margin=12, plot=True):
        secs = 0.0
        labels = self.labels()
        mfcc = self.media.mfcc.T
        mfcc = mfcc[..., np.newaxis]
        pred = net.predict(mfcc)
        x, y = self.logloss(pred, labels, margin=margin)
        accept = True
        if safe:
            # Accept the shift only if its loss is a clear outlier: more than
            # one standard deviation below the mean over all shifts
            mean = np.mean(y)
            sd = np.std(y)
            accept = np.min(y) < mean - sd
        if accept:
            secs = blocksToSeconds(x[np.argmin(y)])
            print("Shift {} seconds:".format(secs))
            self.subs.shift(seconds=secs)
            self.subs.save(self.path, encoding='utf-8')
            if secs != 0.0:
                logger.info('{}: {}s'.format(self.path, secs))
        if plot:
            self.plot_logloss(x, y)
        return secs

    def sync_all(self, net, margin=16, plot=True):
        mfcc = self.media.mfcc.T
        mfcc = mfcc[..., np.newaxis]
        pred = net.predict(mfcc)
        print("Fitting...")
        # Forward the margin; it was previously accepted but never used
        self.__sync_all_rec(self.subs, pred, margin=margin)
        self.clean()
        self.subs.save(self.path, encoding='utf-8')

    def __sync_all_rec(self, subs, pred, margin=16):
        if len(subs) < 3:
            return
        labels = self.labels(subs=subs)
        if np.unique(labels).size <= 1:
            return
        x, y = self.logloss(pred, labels, margin=max(margin, 0.25))
        #self.plot_logloss(x,y)
        #self.plot_labels(labels, pred)
        secs = blocksToSeconds(x[np.argmin(y)])
        subs.shift(seconds=secs)
        # call recursively on each half with half the search margin
        middle = subs[len(subs)//2]
        left = subs.slice(ends_before=middle.start)
        right = subs.slice(starts_after=middle.start)
        self.__sync_all_rec(left, pred, margin=margin/2)
        self.__sync_all_rec(right, pred, margin=margin/2)

    def clean(self):
        # Trim subtitles that overlap the start of the following subtitle
        for i, s in enumerate(self.subs):
            if i >= len(self.subs)-1:
                return
            nxt = self.subs[i+1]
            if s.end > nxt.start:
                s.end = nxt.start

    def plot_logloss(self, x, y):
        import matplotlib.pyplot as plt
        plt.figure()
        plt.plot(x, y)
        plt.title('logloss over shifts')
        plt.ylabel('logloss')
        plt.xlabel('shifts')
        plt.legend(['logloss'], loc='upper left')
        plt.show()

    def plot_labels(self, labels, pred):
        import matplotlib.pyplot as plt
        plt.figure()
        plt.plot([i for i in range(0,len(labels))], labels, label='labels')
        plt.title('labels vs predictions')
        plt.ylabel('value')
        plt.xlabel('time')
        plt.legend(['labels'], loc='upper left')
        plt.figure()
        plt.plot([i for i in range(0,len(pred))], pred, label='pred')
        plt.title('labels vs predictions')
        plt.ylabel('value')
        plt.xlabel('time')
        plt.legend(['pred'], loc='upper left')
        plt.show()


# Convert timestamp to seconds
def timeToSec(t):
    total_sec = float(t.milliseconds)/1000
    total_sec += t.seconds
    total_sec += t.minutes*60
    total_sec += t.hours*60*60
    return total_sec


# Return cell position from timestamp
def timeToPos(t, freq=Media.FREQ, hop_len=Media.HOP_LEN):
    return round(timeToSec(t)/(hop_len/freq))


# Convert seconds to a whole number of mfcc blocks (truncating)
def secondsToBlocks(s, hop_len=Media.HOP_LEN, freq=Media.FREQ):
    return int(float(s)/(hop_len/freq))


# Convert a number of mfcc blocks to seconds
def blocksToSeconds(h, freq=Media.FREQ, hop_len=Media.HOP_LEN):
    return float(h)*(hop_len/freq)
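

# --- Illustrative sketch, not part of the original module -------------------
# A minimal, dependency-free picture of the shift search that
# Subtitle.logloss() and Subtitle.sync() appear to perform: roll the subtitle
# labels over a window of frame offsets, score each offset with the log loss
# against the predicted speech probabilities, and keep the offset with the
# lowest loss. Every _demo_* name and the synthetic data are assumptions made
# for illustration; the module itself uses np.roll and
# sklearn.metrics.log_loss.
import math


def _demo_log_loss(actual, pred, eps=1e-15):
    # Mean binary cross-entropy between 0/1 labels and probabilities
    total = 0.0
    for a, p in zip(actual, pred):
        p = min(max(p, eps), 1.0 - eps)
        total -= a * math.log(p) + (1 - a) * math.log(1.0 - p)
    return total / len(actual)


def _demo_roll(seq, offset):
    # Rotate a list to the right by offset, wrapping around like np.roll
    offset %= len(seq)
    return seq[-offset:] + seq[:-offset] if offset else list(seq)


def _demo_shift_search(blocks=10, hop_len=512.0, freq=16000):
    # Synthetic data: speech is predicted in frames 40..59, while the
    # subtitle labels cover frames 35..54, i.e. five frames too early
    pred = [0.9 if 40 <= i < 60 else 0.1 for i in range(100)]
    labels = [1 if 35 <= i < 55 else 0 for i in range(100)]
    best_offset, best_loss = 0, float('inf')
    for offset in range(-blocks, blocks):
        rolled = _demo_roll(labels, offset)
        loss = _demo_log_loss(rolled[blocks:-blocks], pred[blocks:-blocks])
        if loss < best_loss:
            best_offset, best_loss = offset, loss
    # Convert the winning frame offset to seconds, as blocksToSeconds does
    return best_offset, best_offset * (hop_len / freq)


# Usage: _demo_shift_search() returns the best frame offset together with the
# same shift expressed in seconds. With the synthetic data above, the labels
# have to move five frames later to line up with the predictions.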