# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from __future__ import absolute_import, division, print_function import os import numpy import json import sys import re import scipy.signal import logging import ast import inspect import collections import numbers try: import cPickle as pickle except: import pickle from collections import namedtuple, OrderedDict import time import mxnet as mx import mxnet.ndarray as nd _ctx = mx.cpu() _numpy_rng = numpy.random.RandomState(123456) def get_default_ctx(): return _ctx def get_numpy_rng(): return _numpy_rng def get_saving_path(prefix="", epoch=None): sym_saving_path = os.path.join('%s-symbol.json' % prefix) if epoch is not None: param_saving_path = os.path.join('%s-%05d.params' % (prefix, epoch)) else: param_saving_path = os.path.join('%s.params' % prefix) misc_saving_path = os.path.join('%s-misc.json' % prefix) return sym_saving_path, param_saving_path, misc_saving_path def logging_config(name=None, level=logging.DEBUG, console_level=logging.DEBUG): if name is None: name = inspect.stack()[1][1].split('.')[0] folder = os.path.join(os.getcwd(), name) if not os.path.exists(folder): os.makedirs(folder) logpath = os.path.join(folder, name + ".log") print("All Logs will be saved to %s" %logpath) logging.root.setLevel(level) formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') logfile = logging.FileHandler(logpath) logfile.setLevel(level) logfile.setFormatter(formatter) logging.root.addHandler(logfile) #TODO Update logging patterns in other files logconsole = logging.StreamHandler() logconsole.setLevel(console_level) logconsole.setFormatter(formatter) logging.root.addHandler(logconsole) return folder def save_params(dir_path=os.curdir, epoch=None, name="", params=None, aux_states=None, ctx=mx.cpu()): prefix = os.path.join(dir_path, name) _, param_saving_path, _ = get_saving_path(prefix, epoch) if not os.path.isdir(dir_path) and not (dir_path == ""): os.makedirs(dir_path) save_dict = {('arg:%s' % k): v.copyto(ctx) for k, v in params.items()} save_dict.update({('aux:%s' % k): v.copyto(ctx) for k, v in aux_states.items()}) nd.save(param_saving_path, save_dict) return param_saving_path def save_misc(dir_path=os.curdir, epoch=None, name="", content=None): prefix = os.path.join(dir_path, name) _, _, misc_saving_path = get_saving_path(prefix, epoch) with open(misc_saving_path, 'w') as fp: json.dump(content, fp) return misc_saving_path def quick_save_json(dir_path=os.curdir, file_name="", content=None): file_path = os.path.join(dir_path, file_name) if not os.path.isdir(dir_path): os.makedirs(dir_path) with open(file_path, 'w') as fp: json.dump(content, fp) logging.info('Save json into %s' % file_path) def safe_eval(expr): if type(expr) is str: return ast.literal_eval(expr) else: return expr def norm_clipping(params_grad, threshold): assert isinstance(params_grad, dict) norm_val = numpy.sqrt(sum([nd.norm(grad).asnumpy()[0]**2 for grad in params_grad.values()])) # print('grad norm: %g' % norm_val) ratio = 1.0 if norm_val > threshold: ratio = threshold / norm_val for grad in params_grad.values(): grad *= ratio return norm_val def sample_categorical(prob, rng): """Sample from independent categorical distributions Each batch is an independent categorical distribution. Parameters ---------- prob : numpy.ndarray Probability of the categorical distribution. Shape --> (batch_num, category_num) rng : numpy.random.RandomState Returns ------- ret : numpy.ndarray Sampling result. Shape --> (batch_num,) """ ret = numpy.empty(prob.shape[0], dtype=numpy.float32) for ind in range(prob.shape[0]): ret[ind] = numpy.searchsorted(numpy.cumsum(prob[ind]), rng.rand()).clip(min=0.0, max=prob.shape[ 1] - 0.5) return ret def sample_normal(mean, var, rng): """Sample from independent normal distributions Each element is an independent normal distribution. Parameters ---------- mean : numpy.ndarray Means of the normal distribution. Shape --> (batch_num, sample_dim) var : numpy.ndarray Variance of the normal distribution. Shape --> (batch_num, sample_dim) rng : numpy.random.RandomState Returns ------- ret : numpy.ndarray The sampling result. Shape --> (batch_num, sample_dim) """ ret = numpy.sqrt(var) * rng.randn(*mean.shape) + mean return ret def sample_mog(prob, mean, var, rng): """Sample from independent mixture of gaussian (MoG) distributions Each batch is an independent MoG distribution. Parameters ---------- prob : numpy.ndarray mixture probability of each gaussian. Shape --> (batch_num, center_num) mean : numpy.ndarray mean of each gaussian. Shape --> (batch_num, center_num, sample_dim) var : numpy.ndarray variance of each gaussian. Shape --> (batch_num, center_num, sample_dim) rng : numpy.random.RandomState Returns ------- ret : numpy.ndarray sampling result. Shape --> (batch_num, sample_dim) """ gaussian_inds = sample_categorical(prob, rng).astype(numpy.int32) mean = mean[numpy.arange(mean.shape[0]), gaussian_inds, :] var = var[numpy.arange(mean.shape[0]), gaussian_inds, :] ret = sample_normal(mean=mean, var=var, rng=rng) return ret def npy_softmax(x, axis=1): e_x = numpy.exp(x - numpy.max(x, axis=axis, keepdims=True)) out = e_x / e_x.sum(axis=axis, keepdims=True) return out def npy_sigmoid(x): return 1/(1 + numpy.exp(-x)) def npy_onehot(x, num): ret = numpy.zeros(shape=(x.size, num)) ret[numpy.arange(x.size), x.ravel()] = 1 ret = ret.reshape(x.shape + (num,)) return ret def npy_binary_entropy(prediction, target): assert prediction.shape == target.shape return - (numpy.log(prediction + 1E-9) * target + numpy.log(1 - prediction + 1E-9) * (1 - target)).sum() def block_all(sym_list): return [mx.symbol.BlockGrad(sym) for sym in sym_list] def load_params(dir_path="", epoch=None, name=""): prefix = os.path.join(dir_path, name) _, param_loading_path, _ = get_saving_path(prefix, epoch) while not os.path.isfile(param_loading_path): logging.info("in load_param, %s Not Found!" % param_loading_path) time.sleep(60) save_dict = nd.load(param_loading_path) arg_params = {} aux_params = {} for k, v in save_dict.items(): tp, name = k.split(':', 1) if tp == 'arg': arg_params[name] = v if tp == 'aux': aux_params[name] = v return arg_params, aux_params, param_loading_path def load_misc(dir_path="", epoch=None, name=""): prefix = os.path.join(dir_path, name) _, _, misc_saving_path = get_saving_path(prefix, epoch) with open(misc_saving_path, 'r') as fp: misc = json.load(fp) return misc def load_npz(path): with numpy.load(path) as data: ret = {k: data[k] for k in data.keys()} return ret def discount_cumsum(x, discount): # See https://docs.scipy.org/doc/scipy/reference/tutorial/signal.html#difference-equation-filtering # Here, we have y[t] - discount*y[t+1] = x[t] # or rev(y)[t] - discount*rev(y)[t-1] = rev(x)[t] return scipy.signal.lfilter([1], [1, -discount], x[::-1], axis=0)[::-1] def discount_return(x, discount): return numpy.sum(x * (discount ** numpy.arange(len(x)))) def update_on_kvstore(kv, params, params_grad): for ind, k in enumerate(params.keys()): kv.push(ind, params_grad[k], priority=-ind) kv.pull(ind, params[k], priority=-ind) def parse_ctx(ctx_args): ctx = re.findall('([a-z]+)(\d*)', ctx_args) ctx = [(device, int(num)) if len(num) > 0 else (device, 0) for device, num in ctx] return ctx def get_npy_list(ndarray_list): """Get a numpy-array list from a ndarray list Parameters ---------- ndarray_list : list of NDArray Returns ------- ret : list of numpy.ndarray """ ret = [v.asnumpy() for v in ndarray_list] return ret def get_sym_list(syms, default_names=None, default_shapes=None): if syms is None and default_names is not None: if default_shapes is not None: return [mx.sym.Variable(name=name, shape=shape) for (name, shape) in zip(default_names, default_shapes)] else: return [mx.sym.Variable(name=name) for name in default_names] assert isinstance(syms, (list, tuple, mx.symbol.Symbol)) if isinstance(syms, (list, tuple)): if default_names is not None and len(syms) != len(default_names): raise ValueError("Size of symbols do not match expectation. Received %d, Expected %d. " "syms=%s, names=%s" %(len(syms), len(default_names), str(list(sym.name for sym in syms)), str(default_names))) return list(syms) else: if default_names is not None and len(default_names) != 1: raise ValueError("Size of symbols do not match expectation. Received 1, Expected %d. " "syms=%s, names=%s" % (len(default_names), str([syms.name]), str(default_names))) return [syms] def get_numeric_list(values, typ, expected_len=None): if isinstance(values, numbers.Number): if expected_len is not None: return [typ(values)] * expected_len else: return [typ(values)] elif isinstance(values, (list, tuple)): if expected_len is not None: assert len(values) == expected_len try: ret = [typ(value) for value in values] return ret except(ValueError): print("Need iterable with numeric elements, received: %s" %str(values)) sys.exit(1) else: raise ValueError("Unaccepted value type, values=%s" %str(values)) def get_int_list(values, expected_len=None): return get_numeric_list(values, numpy.int32, expected_len) def get_float_list(values, expected_len=None): return get_numeric_list(values, numpy.float32, expected_len) def get_bucket_key(bucket_kwargs): assert isinstance(bucket_kwargs, dict) return tuple(bucket_kwargs.items())