#!/usr/bin/env python
"""Training/validation/test harness for CTC-based ASR models.

Provides a generic ``Trainer`` plus two concrete variants:

* ``NonSplitTrainer`` -- feeds the whole utterance spectrogram to the model.
* ``SplitTrainer``    -- feeds per-frame image segments and reassembles the
  frame-wise outputs before the CTC loss.

Also contains helpers to initialize ``torch.distributed`` from SLURM/OpenMPI
environments and to set up NVIDIA apex mixed-precision (fp16) training.
"""
import os
import sys
from pathlib import Path

from tqdm import tqdm
import multiprocessing as mp
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
import torchvision.utils as tvu
import torchnet as tnt
import Levenshtein as Lev

from asr.utils.logger import logger
from asr.utils.misc import onehot2int, int2onehot, remove_duplicates, get_model_file_path
from asr.utils.adamw import AdamW
from asr.utils.lr_scheduler import CosineAnnealingWithRestartsLR
from asr.utils import params
from asr.kaldi.latgen import LatGenCTCDecoder

# optimizer names accepted by Trainer(opt_type=...)
OPTIMIZER_TYPES = {
    "sgd",
    "sgdr",
    "adam",
    "adamw",
    "adamwr",
    "rmsprop",
}

torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = False
torch.multiprocessing.freeze_support()


def init_distributed(use_cuda, backend="nccl", init="slurm", local_rank=-1):
    """Initialize torch.distributed, falling back to single-process mode.

    When ``local_rank`` is -1, rank/world_size/local_rank are read from the
    SLURM or OpenMPI environment (per ``init``) and the process group is
    formed over tcp://$MASTER_ADDR:$MASTER_PORT.  Otherwise ``local_rank``
    is trusted (e.g. from ``torch.distributed.launch``) and the "env://"
    init method is used.  Any failure (missing env vars, no cluster) is
    treated as "run as a single process".
    """
    #try:
    #    mp.set_start_method('spawn')  # spawn, forkserver, and fork
    #except RuntimeError:
    #    pass
    try:
        if local_rank == -1:
            if init == "slurm":
                rank = int(os.environ['SLURM_PROCID'])
                world_size = int(os.environ['SLURM_NTASKS'])
                local_rank = int(os.environ['SLURM_LOCALID'])
                #maser_node = os.environ['SLURM_TOPOLOGY_ADDR']
                #maser_port = '23456'
            elif init == "ompi":
                rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
                world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
                local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
            if use_cuda:
                # map this process onto one of the node's GPUs
                device = local_rank % torch.cuda.device_count()
                torch.cuda.set_device(device)
                print(f"set cuda device to cuda:{device}")
            master_node = os.environ["MASTER_ADDR"]
            master_port = os.environ["MASTER_PORT"]
            init_method = f"tcp://{master_node}:{master_port}"
            #init_method = "env://"
            dist.init_process_group(backend=backend, init_method=init_method,
                                    world_size=world_size, rank=rank)
            print(f"initialized as {rank}/{world_size} via {init_method}")
        else:
            if use_cuda:
                torch.cuda.set_device(local_rank)
                print(f"set cuda device to cuda:{local_rank}")
            dist.init_process_group(backend=backend, init_method="env://")
            print(f"initialized as {dist.get_rank()}/{dist.get_world_size()} via env://")
    except Exception:
        # deliberate best-effort: no cluster env means single-process training
        print("initialized as single process")


def is_distributed():
    """Return True when a process group with more than one rank is active."""
    try:
        return (dist.get_world_size() > 1)
    except Exception:
        # dist raises RuntimeError when no process group is initialized
        return False


def get_rank():
    """Return this process's distributed rank, or None when not distributed."""
    try:
        return dist.get_rank()
    except Exception:
        return None


def set_seed(seed=None):
    """Seed torch, numpy and (when available) CUDA RNGs; no-op if seed is None."""
    if seed is not None:
        logger.info(f"set random seed to {seed}")
        torch.manual_seed(seed)
        np.random.seed(seed)
        # fixed: previously referenced an undefined global `args.use_cuda`
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)


def get_amp_handle(args):
    """Return an apex amp handle when fp16 is requested, else None.

    Forces ``args.fp16`` off when CUDA is unavailable, since apex amp
    requires GPU execution.
    """
    if not args.use_cuda:
        args.fp16 = False
    if args.fp16:
        from apex import amp
        amp_handle = amp.init(enabled=True, enable_caching=True, verbose=False)
        return amp_handle
    else:
        return None


class Trainer:
    """Generic train/validate/test driver around a CTC acoustic model.

    Handles optimizer/scheduler construction, optional fp16 (apex amp),
    optional DistributedDataParallel wrapping, checkpointing, and decoding
    through a Kaldi lattice CTC decoder.  Subclasses implement
    ``unit_train`` / ``unit_validate`` / ``unit_test`` for one batch.
    """

    def __init__(self, model, amp_handle=None, init_lr=1e-2, max_norm=100,
                 use_cuda=False, fp16=False, log_dir='logs', model_prefix='model',
                 checkpoint=False, continue_from=None, opt_type=None,
                 *args, **kwargs):
        if fp16:
            import apex.parallel
            from apex import amp
            if not use_cuda:
                raise RuntimeError

        self.amp_handle = amp_handle

        # training parameters
        self.init_lr = init_lr
        self.max_norm = max_norm
        self.use_cuda = use_cuda
        self.fp16 = fp16
        self.log_dir = log_dir
        self.model_prefix = model_prefix
        self.checkpoint = checkpoint
        self.opt_type = opt_type

        self.epoch = 0
        self.states = None
        self.global_step = 0  # for tensorboard

        # load from pre-trained model if needed
        if continue_from is not None:
            self.load(continue_from)

        # setup model
        self.model = model
        if self.use_cuda:
            logger.debug("using cuda")
            self.model.cuda()

        # setup loss (default: mean-reduced CTC with blank index 0)
        loss = kwargs.get('loss', None)
        self.loss = (nn.CTCLoss(blank=0, reduction='mean')
                     if loss is None else loss)

        # setup optimizer
        self.optimizer = None
        self.lr_scheduler = None
        if opt_type is not None:
            assert opt_type in OPTIMIZER_TYPES
            parameters = self.model.parameters()
            if opt_type == "sgdr":
                logger.debug("using SGDR")
                self.optimizer = torch.optim.SGD(parameters, lr=self.init_lr,
                                                 momentum=0.9, weight_decay=1e-4)
                #self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.5)
                self.lr_scheduler = CosineAnnealingWithRestartsLR(self.optimizer,
                                                                  T_max=5, T_mult=2)
            elif opt_type == "adamw":
                logger.debug("using AdamW")
                self.optimizer = AdamW(parameters, lr=self.init_lr,
                                       betas=(0.9, 0.999), eps=1e-8,
                                       weight_decay=1e-4, amsgrad=True)
            elif opt_type == "adamwr":
                logger.debug("using AdamWR")
                self.optimizer = AdamW(parameters, lr=self.init_lr,
                                       betas=(0.9, 0.999), eps=1e-8,
                                       weight_decay=1e-4, amsgrad=True)
                self.lr_scheduler = CosineAnnealingWithRestartsLR(self.optimizer,
                                                                  T_max=5, T_mult=2)
            elif opt_type == "adam":
                logger.debug("using Adam")
                self.optimizer = torch.optim.Adam(parameters, lr=self.init_lr,
                                                  betas=(0.9, 0.999), eps=1e-8,
                                                  weight_decay=1e-4)
            elif opt_type == "rmsprop":
                logger.debug("using RMSprop")
                self.optimizer = torch.optim.RMSprop(parameters, lr=self.init_lr,
                                                     alpha=0.95, eps=1e-8,
                                                     weight_decay=1e-4, centered=True)

        # setup decoder for test
        self.decoder = LatGenCTCDecoder()
        self.labeler = self.decoder.labeler

        # FP16 and distributed wrapping must happen after load
        if self.fp16:
            #self.model = network_to_half(self.model)
            #self.optimizer = FP16_Optimizer(self.optimizer, static_loss_scale=128.)
            self.optimizer = self.amp_handle.wrap_optimizer(self.optimizer)

        if is_distributed():
            if self.use_cuda:
                local_rank = torch.cuda.current_device()
                if fp16:
                    self.model = apex.parallel.DistributedDataParallel(self.model)
                else:
                    self.model = nn.parallel.DistributedDataParallel(self.model,
                                                                     device_ids=[local_rank],
                                                                     output_device=local_rank)
            else:
                self.model = nn.parallel.DistributedDataParallel(self.model)

        if self.states is not None:
            self.restore_state()

    def __get_model_name(self, desc):
        # build "<log_dir>/<prefix>_<desc>" model file path
        return str(get_model_file_path(self.log_dir, self.model_prefix, desc))

    def __remove_ckpt_files(self, epoch):
        # drop intermediate checkpoints once a full-epoch snapshot exists
        for ckpt in Path(self.log_dir).rglob(f"*_epoch_{epoch:03d}_ckpt_*"):
            ckpt.unlink()

    def train_loop_before_hook(self):
        """Subclass hook called once before each training epoch."""
        pass

    def train_loop_checkpoint_hook(self):
        """Subclass hook called after each intra-epoch checkpoint save."""
        pass

    def train_loop_after_hook(self):
        """Subclass hook called once after each training epoch."""
        pass

    def unit_train(self, data):
        """Train on one batch; return the scalar loss value (or None to skip)."""
        raise NotImplementedError

    def train_epoch(self, data_loader):
        """Run one epoch of training with periodic plotting and checkpoints."""
        self.model.train()
        meter_loss = tnt.meter.MovingAverageValueMeter(len(data_loader) // 100 + 1)
        #meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
        #meter_confusion = tnt.meter.ConfusionMeter(params.NUM_CTC_LABELS, normalized=True)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
            logger.debug(f"current lr = {self.optimizer.param_groups[0]['lr']:.3e}")
        if is_distributed() and data_loader.sampler is not None:
            # reshuffle the distributed sampler differently each epoch
            data_loader.sampler.set_epoch(self.epoch)

        # checkpoint roughly every 10% of the epoch
        ckpt_step = 0.1
        ckpts = iter(len(data_loader) * np.arange(ckpt_step, 1 + ckpt_step, ckpt_step))

        def plot_graphs(loss, data_iter=0, title="train", stats=False):
            #if self.lr_scheduler is not None:
            #    self.lr_scheduler.step()
            x = self.epoch + data_iter / len(data_loader)
            self.global_step = int(x / ckpt_step)
            if logger.visdom is not None:
                opts = {
                    'xlabel': 'epoch',
                    'ylabel': 'loss',
                }
                logger.visdom.add_point(title=title, x=x, y=loss, **opts)
            if logger.tensorboard is not None:
                #logger.tensorboard.add_graph(self.model, xs)
                #xs_img = tvu.make_grid(xs[0, 0], normalize=True, scale_each=True)
                #logger.tensorboard.add_image('xs', self.global_step, xs_img)
                #ys_hat_img = tvu.make_grid(ys_hat[0].transpose(0, 1), normalize=True, scale_each=True)
                #logger.tensorboard.add_image('ys_hat', self.global_step, ys_hat_img)
                logger.tensorboard.add_scalars(title, self.global_step, {
                    'loss': loss,
                })
                if stats:
                    for name, param in self.model.named_parameters():
                        logger.tensorboard.add_histogram(name, self.global_step,
                                                         param.clone().cpu().data.numpy())

        self.train_loop_before_hook()
        ckpt = next(ckpts)
        t = tqdm(enumerate(data_loader), total=len(data_loader),
                 desc="training", ncols=params.NCOLS)
        for i, (data) in t:
            loss_value = self.unit_train(data)
            if loss_value is not None:
                meter_loss.add(loss_value)
            t.set_description(f"training (loss: {meter_loss.value()[0]:.3f})")
            t.refresh()
            #self.meter_accuracy.add(ys_int, ys)
            #self.meter_confusion.add(ys_int, ys)
            if i > ckpt:
                plot_graphs(meter_loss.value()[0], i)
                if self.checkpoint:
                    logger.info(f"training loss at epoch_{self.epoch:03d}_ckpt_{i:07d}: "
                                f"{meter_loss.value()[0]:5.3f}")
                    # only rank 0 writes checkpoints in distributed runs
                    if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
                        self.save(self.__get_model_name(f"epoch_{self.epoch:03d}_ckpt_{i:07d}"))
                    self.train_loop_checkpoint_hook()
                ckpt = next(ckpts)

        self.epoch += 1
        logger.info(f"epoch {self.epoch:03d}: "
                    f"training loss {meter_loss.value()[0]:5.3f} ")
                    #f"training accuracy {meter_accuracy.value()[0]:6.3f}")
        if not is_distributed() or (is_distributed() and dist.get_rank() == 0):
            self.save(self.__get_model_name(f"epoch_{self.epoch:03d}"))
            self.__remove_ckpt_files(self.epoch - 1)
        plot_graphs(meter_loss.value()[0], stats=True)
        self.train_loop_after_hook()

    def unit_validate(self, data):
        """Validate one batch; return (hyps, refs) as CTC label sequences."""
        raise NotImplementedError

    def validate(self, data_loader):
        "validate with label error rate by the edit distance between hyps and refs"
        self.model.eval()
        with torch.no_grad():
            N, D = 0, 0
            t = tqdm(enumerate(data_loader), total=len(data_loader),
                     desc="validating", ncols=params.NCOLS)
            for i, (data) in t:
                hyps, refs = self.unit_validate(data)
                # calculate ler
                N += self.edit_distance(refs, hyps)
                D += sum(len(r) for r in refs)
                ler = N * 100. / D
                t.set_description(f"validating (LER: {ler:.2f} %)")
                t.refresh()
            logger.info(f"validating at epoch {self.epoch:03d}: LER {ler:.2f} %")

            title = f"validate"
            x = self.epoch - 1 + i / len(data_loader)
            if logger.visdom is not None:
                opts = {
                    'xlabel': 'epoch',
                    'ylabel': 'LER',
                }
                logger.visdom.add_point(title=title, x=x, y=ler, **opts)
            if logger.tensorboard is not None:
                logger.tensorboard.add_scalars(title, self.global_step, {
                    'LER': ler,
                })

    def unit_test(self, data):
        """Test one batch; return (hyps, refs) as word-index sequences."""
        raise NotImplementedError

    def test(self, data_loader):
        "test with word error rate by the edit distance between hyps and refs"
        self.model.eval()
        with torch.no_grad():
            N, D = 0, 0
            t = tqdm(enumerate(data_loader), total=len(data_loader),
                     desc="testing", ncols=params.NCOLS)
            for i, (data) in t:
                hyps, refs = self.unit_test(data)
                # calculate wer
                N += self.edit_distance(refs, hyps)
                D += sum(len(r) for r in refs)
                wer = N * 100. / D
                t.set_description(f"testing (WER: {wer:.2f} %)")
                t.refresh()
            logger.info(f"testing at epoch {self.epoch:03d}: WER {wer:.2f} %")

    def edit_distance(self, refs, hyps):
        """Sum of Levenshtein distances between paired ref/hyp int sequences.

        Each integer label is mapped to a unicode char so that the C-level
        string edit distance can be used.
        """
        assert len(refs) == len(hyps)
        n = 0
        for ref, hyp in zip(refs, hyps):
            r = [chr(c) for c in ref]
            h = [chr(c) for c in hyp]
            n += Lev.distance(''.join(r), ''.join(h))
        return n

    def target_to_loglikes(self, ys, label_lens):
        """Convert flat target labels into pseudo log-likelihoods.

        Builds near-one-hot frames (floor 1e-3 to avoid log(0)) from the
        targets, used to sanity-check the decoder with perfect "acoustics".
        """
        max_len = max(label_lens.tolist())
        num_classes = self.labeler.get_num_labels()
        ys_hat = [torch.cat((torch.zeros(1).int(), ys[s:s + l],
                             torch.zeros(max_len - l).int()))
                  for s, l in zip([0] + label_lens[:-1].cumsum(0).tolist(),
                                  label_lens.tolist())]
        ys_hat = [int2onehot(torch.IntTensor(z), num_classes, floor=1e-3) for z in ys_hat]
        ys_hat = torch.stack(ys_hat)
        ys_hat = torch.log(ys_hat)
        return ys_hat

    def save_hook(self):
        """Subclass hook to add extra entries to ``self.states`` before save."""
        pass

    def save(self, file_path, **kwargs):
        """Serialize model/optimizer/scheduler state (plus kwargs) to a file."""
        Path(file_path).parent.mkdir(mode=0o755, parents=True, exist_ok=True)
        logger.debug(f"saving the model to {file_path}")
        if self.states is None:
            self.states = dict()
        self.states.update(kwargs)
        self.states["epoch"] = self.epoch
        self.states["opt_type"] = self.opt_type
        if is_distributed():
            model_state_dict = self.model.state_dict()
            strip_prefix = 9 if self.fp16 else 7
            # remove "module.1." prefix from keys
            self.states["model"] = {k[strip_prefix:]: v for k, v in model_state_dict.items()}
        else:
            self.states["model"] = self.model.state_dict()
        self.states["optimizer"] = self.optimizer.state_dict()
        if self.lr_scheduler is not None:
            self.states["lr_scheduler"] = self.lr_scheduler.state_dict()
        self.save_hook()
        torch.save(self.states, file_path)

    def load(self, file_path):
        """Load a checkpoint into ``self.states`` (applied by restore_state)."""
        if isinstance(file_path, str):
            file_path = Path(file_path)
        if not file_path.exists():
            logger.error(f"no such file {file_path} exists")
            sys.exit(1)
        logger.debug(f"loading the model from {file_path}")
        to_device = f"cuda:{torch.cuda.current_device()}" if self.use_cuda else "cpu"
        self.states = torch.load(file_path, map_location=to_device)

    def restore_state(self):
        """Apply a previously loaded ``self.states`` checkpoint dict."""
        self.epoch = self.states["epoch"]
        self.global_step = self.epoch * 10
        if is_distributed():
            # DDP-wrapped model expects "module."-prefixed keys
            self.model.load_state_dict({f"module.{k}": v
                                        for k, v in self.states["model"].items()})
        else:
            self.model.load_state_dict(self.states["model"])
        # only restore optimizer state when the optimizer type matches
        if "opt_type" in self.states and self.opt_type == self.states["opt_type"]:
            self.optimizer.load_state_dict(self.states["optimizer"])
        if self.lr_scheduler is not None and "lr_scheduler" in self.states:
            self.lr_scheduler.load_state_dict(self.states["lr_scheduler"])
            #for _ in range(self.epoch-1):
            #    self.lr_scheduler.step()


class NonSplitTrainer(Trainer):
    """training model for overall utterance spectrogram as a single image"""

    def unit_train(self, data):
        """One training step; returns the loss value or None on nan/inf loss."""
        xs, ys, frame_lens, label_lens, filenames, _ = data
        try:
            batch_size = xs.size(0)
            if self.use_cuda:
                xs = xs.cuda(non_blocking=True)
            ys_hat, frame_lens = self.model(xs, frame_lens)
            if self.fp16:
                ys_hat = ys_hat.float()
            ys_hat = ys_hat.transpose(0, 1).contiguous()  # TxNxH
            #torch.set_printoptions(threshold=5000000)
            #print(ys_hat.shape, frame_lens, ys.shape, label_lens)
            #print(onehot2int(ys_hat).squeeze(), ys)
            loss = self.loss(ys_hat, ys, frame_lens, label_lens)
            if torch.isnan(loss) or loss.item() == float("inf") or loss.item() == -float("inf"):
                logger.warning("received an nan/inf loss: probably frame_lens < label_lens "
                               "or the learning rate is too high")
                #raise RuntimeError
                return None
            if frame_lens.cpu().lt(2 * label_lens).nonzero().numel():
                # CTC needs enough frames to emit blanks between repeats
                logger.debug("the batch includes a data with frame_lens < 2*label_lens: "
                             "set loss to zero")
                loss.mul_(0)
            loss_value = loss.item()
            self.optimizer.zero_grad()
            if self.fp16:
                #self.optimizer.backward(loss)
                #self.optimizer.clip_master_grads(self.max_norm)
                with self.optimizer.scale_loss(loss) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
            self.optimizer.step()
            if self.use_cuda:
                torch.cuda.synchronize()
            del loss
            return loss_value
        except Exception as e:
            print(e)
            print(filenames, frame_lens, label_lens)
            raise

    def unit_validate(self, data):
        """Greedy-decode one batch; return (hyps, refs) CTC label sequences."""
        xs, ys, frame_lens, label_lens, filenames, _ = data
        if self.use_cuda:
            xs = xs.cuda(non_blocking=True)
        ys_hat, frame_lens = self.model(xs, frame_lens)
        if self.fp16:
            ys_hat = ys_hat.float()
        # convert likes to ctc labels
        hyps = [onehot2int(yh[:s]).squeeze() for yh, s in zip(ys_hat, frame_lens)]
        hyps = [remove_duplicates(h, blank=0) for h in hyps]
        # slice the targets
        pos = torch.cat((torch.zeros((1, ), dtype=torch.long),
                         torch.cumsum(label_lens, dim=0)))
        refs = [ys[s:l] for s, l in zip(pos[:-1], pos[1:])]
        return hyps, refs

    def unit_test(self, data, target_test=False):
        """Lattice-decode one batch; return (hyps, refs) as word indices.

        With ``target_test`` the decoder is driven by pseudo log-likelihoods
        built from the targets instead of the model output.
        """
        xs, ys, frame_lens, label_lens, filenames, texts = data
        if not target_test:
            if self.use_cuda:
                xs = xs.cuda(non_blocking=True)
            ys_hat, frame_lens = self.model(xs, frame_lens)
            if self.fp16:
                ys_hat = ys_hat.float()
        else:
            ys_hat = self.target_to_loglikes(ys, label_lens)
        # latgen decoding
        if self.use_cuda:
            ys_hat = ys_hat.cpu()
        words, alignment, w_sizes, a_sizes = self.decoder(ys_hat, frame_lens)
        w2i = self.labeler.word2idx
        num_words = self.labeler.get_num_words()
        # map out-of-vocabulary indices to <unk>
        words.masked_fill_(words.ge(num_words), w2i('<unk>'))
        words.masked_fill_(words.lt(0), w2i('<unk>'))
        hyps = [w[:s] for w, s in zip(words, w_sizes)]
        # convert target texts to word indices
        refs = [[w2i(w.strip()) for w in t.strip().split()] for t in texts]
        return hyps, refs


class SplitTrainer(Trainer):
    """
    training model for splitting utterance into multiple images
    single image stands for localized timing segment corresponding to frame output
    """

    def unit_train(self, data):
        """One training step over per-frame segments; returns the loss value."""
        xs, ys, frame_lens, label_lens, filenames, _ = data
        try:
            if self.use_cuda:
                xs = xs.cuda(non_blocking=True)
            ys_hat = self.model(xs)
            if self.fp16:
                ys_hat = ys_hat.float()
            # regroup the flat frame outputs back into per-utterance tensors
            ys_hat = ys_hat.unsqueeze(dim=0).transpose(1, 2)
            pos = torch.cat((torch.zeros((1, ), dtype=torch.long),
                             torch.cumsum(frame_lens, dim=0)))
            ys_hats = [ys_hat.narrow(2, p, l).clone()
                       for p, l in zip(pos[:-1], frame_lens)]
            max_len = torch.max(frame_lens)
            ys_hats = [nn.ConstantPad1d((0, max_len - yh.size(2)), 0)(yh)
                       for yh in ys_hats]
            ys_hat = torch.cat(ys_hats).transpose(1, 2).transpose(0, 1)
            loss = self.loss(ys_hat, ys, frame_lens, label_lens)
            loss_value = loss.item()
            self.optimizer.zero_grad()
            if self.fp16:
                #self.optimizer.backward(loss)
                #self.optimizer.clip_master_grads(self.max_norm)
                with self.optimizer.scale_loss(loss) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
            self.optimizer.step()
            del loss
        except Exception as e:
            print(filenames, frame_lens, label_lens)
            raise
        return loss_value

    def unit_validate(self, data):
        """Greedy-decode one batch; return (hyps, refs) CTC label sequences."""
        xs, ys, frame_lens, label_lens, filenames, _ = data
        if self.use_cuda:
            xs = xs.cuda(non_blocking=True)
        ys_hat = self.model(xs)
        if self.fp16:
            ys_hat = ys_hat.float()
        pos = torch.cat((torch.zeros((1, ), dtype=torch.long),
                         torch.cumsum(frame_lens, dim=0)))
        ys_hat = [ys_hat.narrow(0, p, l).clone()
                  for p, l in zip(pos[:-1], frame_lens)]
        # convert likes to ctc labels
        hyps = [onehot2int(yh[:s]).squeeze() for yh, s in zip(ys_hat, frame_lens)]
        hyps = [remove_duplicates(h, blank=0) for h in hyps]
        # slice the targets
        pos = torch.cat((torch.zeros((1, ), dtype=torch.long),
                         torch.cumsum(label_lens, dim=0)))
        refs = [ys[s:l] for s, l in zip(pos[:-1], pos[1:])]
        return hyps, refs

    def unit_test(self, data, target_test=False):
        """Lattice-decode one batch; return (hyps, refs) as word indices."""
        xs, ys, frame_lens, label_lens, filenames, texts = data
        if not target_test:
            if self.use_cuda:
                xs = xs.cuda(non_blocking=True)
            ys_hat = self.model(xs)
            if self.fp16:
                ys_hat = ys_hat.float()
            ys_hat = ys_hat.unsqueeze(dim=0).transpose(1, 2)
            pos = torch.cat((torch.zeros((1, ), dtype=torch.long),
                             torch.cumsum(frame_lens, dim=0)))
            ys_hats = [ys_hat.narrow(2, p, l).clone()
                       for p, l in zip(pos[:-1], frame_lens)]
            max_len = torch.max(frame_lens)
            ys_hats = [nn.ConstantPad1d((0, max_len - yh.size(2)), 0)(yh)
                       for yh in ys_hats]
            ys_hat = torch.cat(ys_hats).transpose(1, 2)
        else:
            ys_hat = self.target_to_loglikes(ys, label_lens)
        # latgen decoding
        if self.use_cuda:
            ys_hat = ys_hat.cpu()
        words, alignment, w_sizes, a_sizes = self.decoder(ys_hat, frame_lens)
        w2i = self.labeler.word2idx
        num_words = self.labeler.get_num_words()
        # map out-of-vocabulary indices to <unk>
        words.masked_fill_(words.ge(num_words), w2i('<unk>'))
        words.masked_fill_(words.lt(0), w2i('<unk>'))
        hyps = [w[:s] for w, s in zip(words, w_sizes)]
        # convert target texts to word indices
        refs = [[w2i(w.strip()) for w in t.strip().split()] for t in texts]
        return hyps, refs


if __name__ == "__main__":
    pass