import os from itertools import islice def iter_str_every(iterable, k): """ :param iterable: :param k: :return: """ i = iter(iterable) piece = ''.join(islice(i, k)) while piece: yield piece piece = ''.join(islice(i, k)) def get_sparsity(param): """ :param param: :return: """ mask = param.eq(0) return float(mask.sum()) / mask.numel() class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n if self.count > 0: self.avg = self.sum / self.count def accumulate(self, val, n=1): self.sum += val self.count += n if self.count > 0: self.avg = self.sum / self.count class Logger(object): def __init__(self, file_path): """ write log to file :param file_path: str, path to the file """ self.f = open(file_path, 'w') self.fid = self.f.fileno() self.filepath = file_path def close(self): """ close log file :return: """ return self.f.close() def flush(self): self.f.flush() os.fsync(self.fid) def write(self, content, wrap=True, flush=False, verbose=False): """ write file and flush buffer to the disk :param content: str :param wrap: bool, whether to add '\n' at the end of the content :param flush: bool, whether to flush buffer to the disk, default=False :param verbose: bool, whether to print the content, default=False :return: void """ if verbose: print(content) if wrap: content += "\n" self.f.write(content) if flush: self.f.flush() os.fsync(self.fid) class StageScheduler(object): def __init__(self, max_num_stage, stage_step=45): """ :param max_num_stage: :param stage_step: """ self.max_num_stage = max_num_stage self.stage_step = stage_step if isinstance(stage_step, int): self.stage_step = [stage_step] * max_num_stage if isinstance(stage_step, str): self.stage_step = list(map(int, stage_step.split(','))) assert isinstance(self.stage_step, list) num_stage = len(self.stage_step) if num_stage < self.max_num_stage: for i in range(self.max_num_stage - num_stage): self.stage_step.append(self.stage_step[num_stage - 1]) elif num_stage > self.max_num_stage: self.max_num_stage = num_stage assert len(self.stage_step) == self.max_num_stage for i in range(1, self.max_num_stage): self.stage_step[i] += self.stage_step[i - 1] def step(self, epoch): """ :param epoch: :return: """ stage = self.max_num_stage - 1 for i, max_epoch in enumerate(self.stage_step): if epoch < max_epoch: stage = i break if stage > 0: epoch -= self.stage_step[stage - 1] return stage, epoch